Unverified Commit b658be6f authored by Chang Su, committed by GitHub

[router][grpc] Support tool call parser in streaming (#11160)

parent 5e786cca
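In short: the old `ParserRegistry` + caller-owned `ParseState` + `StreamResult` streaming API is replaced by a pooled `ToolParserFactory` whose parsers own their buffers and return `StreamingParseResult` (passthrough text plus incremental `ToolCallItem` deltas). A minimal sketch of the new calling convention, assembled from the diff below; the driver loop, `model_id`, `stream_of_model_output`, and `tools` are illustrative only:

// Illustrative driver loop over the types introduced in this PR.
let factory = ToolParserFactory::new();
let pooled = factory.get_pooled(model_id); // Arc<tokio::sync::Mutex<Box<dyn ToolParser>>>
for chunk in stream_of_model_output {
    let mut parser = pooled.lock().await;
    let result = parser.parse_incremental(chunk, &tools).await?;
    // result.normal_text: text to pass through; result.calls: ToolCallItem deltas
}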
@@ -8,9 +8,9 @@
 //! - Different model formats (JSON, Mistral, Qwen, Pythonic, etc.)
 use criterion::{black_box, criterion_group, BenchmarkId, Criterion, Throughput};
-use sglang_router_rs::tool_parser::{
-    registry::ParserRegistry, state::ParseState, types::StreamResult,
-};
+use serde_json::json;
+use sglang_router_rs::protocols::spec::{Function, Tool};
+use sglang_router_rs::tool_parser::{JsonParser, ToolParser, ToolParserFactory};
 use std::collections::BTreeMap;
 use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
 use std::sync::{Arc, Mutex};
@@ -108,6 +108,40 @@ const STEP3_FORMAT: &str = r#"<step.tML version="0.1">
 const GPT_OSS_FORMAT: &str = r#"<Channel.vector_search>{"collection": "technical_documentation", "query_embedding": [0.0234, -0.1456, 0.0891, 0.2341, -0.0567, 0.1234, 0.0456, -0.0789, 0.1567, 0.0234, -0.1123, 0.0678, 0.2345, -0.0456, 0.0891, 0.1234, -0.0567, 0.0789, 0.1456, -0.0234, 0.0891, 0.1567, -0.0678, 0.0345, 0.1234, -0.0456, 0.0789, 0.1891, -0.0234, 0.0567, 0.1345, -0.0891], "top_k": 10, "similarity_metric": "cosine", "filters": {"language": "en", "last_updated": {"$gte": "2023-01-01"}, "categories": {"$in": ["api", "sdk", "integration"]}}, "include_metadata": true, "rerank_with_cross_encoder": true}</Channel.vector_search>"#;
+
+// Create test tools for parsers that need them
+fn create_test_tools() -> Vec<Tool> {
+    vec![
+        Tool {
+            tool_type: "function".to_string(),
+            function: Function {
+                name: "search".to_string(),
+                description: Some("Search for information".to_string()),
+                parameters: json!({
+                    "type": "object",
+                    "properties": {
+                        "query": {"type": "string"},
+                        "limit": {"type": "number"}
+                    }
+                }),
+            },
+        },
+        Tool {
+            tool_type: "function".to_string(),
+            function: Function {
+                name: "code_interpreter".to_string(),
+                description: Some("Execute code".to_string()),
+                parameters: json!({
+                    "type": "object",
+                    "properties": {
+                        "language": {"type": "string"},
+                        "code": {"type": "string"}
+                    }
+                }),
+            },
+        },
+    ]
+}
+
 // Large test data for stress testing
 fn generate_large_json(num_tools: usize) -> String {
     let mut tools = Vec::new();
@@ -141,7 +175,7 @@ fn bench_registry_creation(c: &mut Criterion) {
         b.iter_custom(|iters| {
             let start = Instant::now();
             for _ in 0..iters {
-                let registry = black_box(ParserRegistry::new());
+                let registry = black_box(ToolParserFactory::new());
                 // Force evaluation to prevent optimization
                 black_box(registry.list_parsers());
             }
@@ -168,7 +202,7 @@ fn bench_registry_creation(c: &mut Criterion) {
 }
 fn bench_parser_lookup(c: &mut Criterion) {
-    let registry = Arc::new(ParserRegistry::new());
+    let registry = Arc::new(ToolParserFactory::new());
     let models = vec![
         "gpt-4",
         "mistral-large",
@@ -227,7 +261,7 @@ fn bench_parser_lookup(c: &mut Criterion) {
 fn bench_complete_parsing(c: &mut Criterion) {
     let rt = Runtime::new().unwrap();
-    let registry = Arc::new(ParserRegistry::new());
+    let registry = Arc::new(ToolParserFactory::new());
     let test_cases = vec![
         ("json_simple", "json", JSON_SIMPLE),
@@ -295,7 +329,6 @@ fn bench_complete_parsing(c: &mut Criterion) {
 fn bench_streaming_parsing(c: &mut Criterion) {
     let rt = Runtime::new().unwrap();
-    let registry = Arc::new(ParserRegistry::new());
     // Streaming test with chunked input
     let chunks = vec![
@@ -315,24 +348,21 @@ fn bench_streaming_parsing(c: &mut Criterion) {
     let printed = Arc::new(AtomicBool::new(false));
     group.bench_function("json_streaming", |b| {
         let printed_clone = printed.clone();
-        let registry = registry.clone();
         let rt = rt.handle().clone();
         b.iter_custom(|iters| {
-            let parser = registry.get_parser("json").expect("Parser not found");
+            let tools = create_test_tools();
             let start = Instant::now();
             for _ in 0..iters {
-                let parser = parser.clone();
-                let mut state = ParseState::new();
+                let mut parser = JsonParser::new();
                 let mut complete_tools = Vec::new();
                 rt.block_on(async {
                     for chunk in &chunks {
-                        if let StreamResult::ToolComplete(tool) =
-                            parser.parse_incremental(chunk, &mut state).await.unwrap()
-                        {
-                            complete_tools.push(tool);
+                        let result = parser.parse_incremental(chunk, &tools).await.unwrap();
+                        if !result.calls.is_empty() {
+                            complete_tools.extend(result.calls);
                         }
                     }
                 });
@@ -368,7 +398,7 @@ fn bench_streaming_parsing(c: &mut Criterion) {
 fn bench_concurrent_parsing(c: &mut Criterion) {
     let rt = Runtime::new().unwrap();
-    let registry = Arc::new(ParserRegistry::new());
+    let registry = Arc::new(ToolParserFactory::new());
     let parser = registry.get_parser("json").expect("Parser not found");
     let thread_counts = vec![1, 2, 4, 8, 16, 32];
@@ -456,7 +486,7 @@ fn bench_concurrent_parsing(c: &mut Criterion) {
 fn bench_large_payloads(c: &mut Criterion) {
     let rt = Runtime::new().unwrap();
-    let registry = Arc::new(ParserRegistry::new());
+    let registry = Arc::new(ToolParserFactory::new());
     let parser = registry.get_parser("json").expect("Parser not found");
     let sizes = vec![1, 10, 50, 100, 500];
@@ -526,7 +556,7 @@ fn bench_parser_reuse(c: &mut Criterion) {
         b.iter_custom(|iters| {
             let start = Instant::now();
             for _ in 0..iters {
-                let registry = ParserRegistry::new();
+                let registry = ToolParserFactory::new();
                 let parser = registry.get_parser("json").unwrap();
                 let result = rt.block_on(async { parser.parse_complete(JSON_SIMPLE).await });
                 black_box(result.unwrap());
@@ -552,7 +582,7 @@ fn bench_parser_reuse(c: &mut Criterion) {
     // Benchmark reusing registry
     let printed_reuse = Arc::new(AtomicBool::new(false));
-    let shared_registry = Arc::new(ParserRegistry::new());
+    let shared_registry = Arc::new(ToolParserFactory::new());
     group.bench_function("reuse_registry", |b| {
         let printed_clone = printed_reuse.clone();
@@ -627,7 +657,7 @@ fn bench_parser_reuse(c: &mut Criterion) {
 fn bench_latency_distribution(c: &mut Criterion) {
     let rt = Runtime::new().unwrap();
-    let registry = Arc::new(ParserRegistry::new());
+    let registry = Arc::new(ToolParserFactory::new());
     let test_cases = vec![
         ("json", JSON_SIMPLE),
...
@@ -7,7 +7,7 @@ use crate::policies::PolicyRegistry;
 use crate::reasoning_parser::ParserFactory;
 use crate::routers::RouterTrait;
 use crate::tokenizer::traits::Tokenizer;
-use crate::tool_parser::ParserRegistry;
+use crate::tool_parser::ToolParserFactory;
 use async_trait::async_trait;
 use axum::{
     body::Body,
@@ -25,7 +25,7 @@ pub struct GrpcPDRouter {
     policy_registry: Arc<PolicyRegistry>,
     tokenizer: Arc<dyn Tokenizer>,
     reasoning_parser_factory: ParserFactory,
-    tool_parser_registry: &'static ParserRegistry,
+    tool_parser_factory: ToolParserFactory,
     dp_aware: bool,
     api_key: Option<String>,
@@ -50,9 +50,11 @@ impl GrpcPDRouter {
             .as_ref()
             .ok_or_else(|| "gRPC PD router requires reasoning parser factory".to_string())?
             .clone();
-        let tool_parser_registry = ctx
-            .tool_parser_registry
-            .ok_or_else(|| "gRPC PD router requires tool parser registry".to_string())?;
+        let tool_parser_factory = ctx
+            .tool_parser_factory
+            .as_ref()
+            .ok_or_else(|| "gRPC PD router requires tool parser factory".to_string())?
+            .clone();
         // Get prefill and decode workers from registry - they should have been created by WorkerManager
         let prefill_workers = worker_registry.get_workers_filtered(
@@ -86,7 +88,7 @@ impl GrpcPDRouter {
             policy_registry,
             tokenizer,
             reasoning_parser_factory,
-            tool_parser_registry,
+            tool_parser_factory,
             dp_aware: ctx.router_config.dp_aware,
             api_key: ctx.router_config.api_key.clone(),
             retry_config: ctx.router_config.effective_retry_config(),
...
@@ -34,7 +34,7 @@ use crate::tokenizer::stop::{
 };
 use crate::tokenizer::traits::Tokenizer;
 use crate::tokenizer::HuggingFaceTokenizer;
-use crate::tool_parser::ParserRegistry;
+use crate::tool_parser::ToolParserFactory;
 use proto::generate_response::Response::{Chunk, Complete, Error};
 use serde_json::{json, Map, Value};
 use std::time::{Instant, SystemTime, UNIX_EPOCH};
@@ -56,7 +56,7 @@ pub struct GrpcRouter {
     policy_registry: Arc<PolicyRegistry>,
     tokenizer: Arc<dyn Tokenizer>,
     reasoning_parser_factory: ParserFactory,
-    tool_parser_registry: &'static ParserRegistry,
+    tool_parser_factory: ToolParserFactory,
     dp_aware: bool,
     api_key: Option<String>,
     retry_config: RetryConfig,
@@ -76,9 +76,11 @@ impl GrpcRouter {
             .as_ref()
             .ok_or_else(|| "gRPC router requires reasoning parser factory".to_string())?
             .clone();
-        let tool_parser_registry = ctx
-            .tool_parser_registry
-            .ok_or_else(|| "gRPC router requires tool parser registry".to_string())?;
+        let tool_parser_factory = ctx
+            .tool_parser_factory
+            .as_ref()
+            .ok_or_else(|| "gRPC router requires tool parser factory".to_string())?
+            .clone();
         let worker_registry = ctx.worker_registry.clone();
         let policy_registry = ctx.policy_registry.clone();
@@ -98,7 +100,7 @@ impl GrpcRouter {
             policy_registry,
             tokenizer,
             reasoning_parser_factory,
-            tool_parser_registry,
+            tool_parser_factory,
             dp_aware: ctx.router_config.dp_aware,
             api_key: ctx.router_config.api_key.clone(),
             retry_config: ctx.router_config.effective_retry_config(),
@@ -779,15 +781,28 @@ impl GrpcRouter {
         processed_text: &str,
         model: &str,
     ) -> (Option<Vec<ToolCall>>, String) {
-        let Some(parser) = self.tool_parser_registry.get_parser(model) else {
-            return (None, processed_text.to_string());
-        };
+        // Get pooled parser for this model
+        let pooled_parser = self.tool_parser_factory.get_pooled(model);
+
+        // Check format detection first
+        let can_parse = {
+            let parser = pooled_parser.lock().await;
+            parser.detect_format(processed_text)
+            // Lock is dropped here
+        };
-        if !parser.detect_format(processed_text) {
+        if !can_parse {
             return (None, processed_text.to_string());
         }
-        match parser.parse_complete(processed_text).await {
+
+        // Lock again for async parsing
+        let result = {
+            let parser = pooled_parser.lock().await;
+            parser.parse_complete(processed_text).await
+            // Lock is dropped here
+        };
+
+        match result {
             Ok((normal_text, parsed_tool_calls)) => {
                 if parsed_tool_calls.is_empty() {
                     return (None, normal_text);
...
@@ -19,7 +19,7 @@ use crate::{
     routers::{router_manager::RouterManager, RouterTrait},
     service_discovery::{start_service_discovery, ServiceDiscoveryConfig},
     tokenizer::{factory as tokenizer_factory, traits::Tokenizer},
-    tool_parser::ParserRegistry,
+    tool_parser::ToolParserFactory,
 };
 use axum::{
     extract::{Path, Query, Request, State},
@@ -46,7 +46,7 @@ pub struct AppContext {
     pub rate_limiter: Arc<TokenBucket>,
     pub tokenizer: Option<Arc<dyn Tokenizer>>,
     pub reasoning_parser_factory: Option<ParserFactory>,
-    pub tool_parser_registry: Option<&'static ParserRegistry>,
+    pub tool_parser_factory: Option<ToolParserFactory>,
     pub worker_registry: Arc<WorkerRegistry>,
     pub policy_registry: Arc<PolicyRegistry>,
     pub router_manager: Option<Arc<RouterManager>>,
@@ -64,7 +64,7 @@ impl AppContext {
         let rate_limit_tokens = rate_limit_tokens_per_second.unwrap_or(max_concurrent_requests);
         let rate_limiter = Arc::new(TokenBucket::new(max_concurrent_requests, rate_limit_tokens));
-        let (tokenizer, reasoning_parser_factory, tool_parser_registry) =
+        let (tokenizer, reasoning_parser_factory, tool_parser_factory) =
             if router_config.connection_mode == ConnectionMode::Grpc {
                 let tokenizer_path = router_config
                     .tokenizer_path
@@ -80,9 +80,9 @@ impl AppContext {
                         .map_err(|e| format!("Failed to create tokenizer: {e}"))?,
                 );
                 let reasoning_parser_factory = Some(ParserFactory::new());
-                let tool_parser_registry = Some(ParserRegistry::new());
-                (tokenizer, reasoning_parser_factory, tool_parser_registry)
+                let tool_parser_factory = Some(ToolParserFactory::new());
+                (tokenizer, reasoning_parser_factory, tool_parser_factory)
             } else {
                 (None, None, None)
             };
@@ -121,7 +121,7 @@ impl AppContext {
             rate_limiter,
             tokenizer,
             reasoning_parser_factory,
-            tool_parser_registry,
+            tool_parser_factory,
             worker_registry,
             policy_registry,
             router_manager,
...
@@ -539,7 +539,7 @@ mod tests {
             )),
             tokenizer: None,
             reasoning_parser_factory: None,
-            tool_parser_registry: None,
+            tool_parser_factory: None,
             router_manager: None,
             response_storage: Arc::new(crate::data_connector::MemoryResponseStorage::new()),
             load_monitor: None,
...
// Factory and pool for creating model-specific tool parsers with pooling support.
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tokio::sync::Mutex;
use crate::tool_parser::parsers::{
DeepSeekParser, Glm4MoeParser, GptOssHarmonyParser, GptOssParser, JsonParser, KimiK2Parser,
LlamaParser, MistralParser, PythonicParser, QwenParser, Step3Parser,
};
use crate::tool_parser::traits::ToolParser;
/// Type alias for pooled parser instances.
pub type PooledToolParser = Arc<Mutex<Box<dyn ToolParser>>>;
/// Type alias for parser creator functions.
type ParserCreator = Arc<dyn Fn() -> Box<dyn ToolParser> + Send + Sync>;
/// Registry for model-specific tool parsers with pooling support.
#[derive(Clone)]
pub struct ToolParserRegistry {
/// Creator functions for parsers (used when pool is empty)
creators: Arc<RwLock<HashMap<String, ParserCreator>>>,
/// Pooled parser instances for reuse
pool: Arc<RwLock<HashMap<String, PooledToolParser>>>,
/// Model pattern to parser name mappings
model_mapping: Arc<RwLock<HashMap<String, String>>>,
/// Default parser name
default_parser: Arc<RwLock<String>>,
}
impl ToolParserRegistry {
/// Create a new empty registry.
pub fn new() -> Self {
Self {
creators: Arc::new(RwLock::new(HashMap::new())),
pool: Arc::new(RwLock::new(HashMap::new())),
model_mapping: Arc::new(RwLock::new(HashMap::new())),
default_parser: Arc::new(RwLock::new("json".to_string())),
}
}
/// Register a parser creator for a given parser type.
pub fn register_parser<F>(&self, name: &str, creator: F)
where
F: Fn() -> Box<dyn ToolParser> + Send + Sync + 'static,
{
let mut creators = self.creators.write().unwrap();
creators.insert(name.to_string(), Arc::new(creator));
}
/// Map a model name/pattern to a parser
pub fn map_model(&self, model: impl Into<String>, parser: impl Into<String>) {
let mut mapping = self.model_mapping.write().unwrap();
mapping.insert(model.into(), parser.into());
}
/// Get a pooled parser by exact name.
/// Returns a shared parser instance from the pool, creating one if needed.
pub fn get_pooled_parser(&self, name: &str) -> Option<PooledToolParser> {
// First check if we have a pooled instance
{
let pool = self.pool.read().unwrap();
if let Some(parser) = pool.get(name) {
return Some(Arc::clone(parser));
}
}
// If not in pool, create one and add to pool
let creators = self.creators.read().unwrap();
if let Some(creator) = creators.get(name) {
let parser = Arc::new(Mutex::new(creator()));
// Add to pool for future use
let mut pool = self.pool.write().unwrap();
pool.insert(name.to_string(), Arc::clone(&parser));
Some(parser)
} else {
None
}
}
/// Get parser for a specific model
pub fn get_pooled_for_model(&self, model: &str) -> Option<PooledToolParser> {
// Try exact match first
{
let mapping = self.model_mapping.read().unwrap();
if let Some(parser_name) = mapping.get(model) {
if let Some(parser) = self.get_pooled_parser(parser_name) {
return Some(parser);
}
}
}
// Try prefix matching with more specific patterns first
let model_mapping = self.model_mapping.read().unwrap();
let best_match = model_mapping
.iter()
.filter(|(pattern, _)| {
pattern.ends_with('*') && model.starts_with(&pattern[..pattern.len() - 1])
})
.max_by_key(|(pattern, _)| pattern.len());
// Return the best matching parser
if let Some((_, parser_name)) = best_match {
if let Some(parser) = self.get_pooled_parser(parser_name) {
return Some(parser);
}
}
// Fall back to default parser
let default = self.default_parser.read().unwrap().clone();
self.get_pooled_parser(&default)
}
/// Clear the parser pool, forcing new instances to be created.
pub fn clear_pool(&self) {
let mut pool = self.pool.write().unwrap();
pool.clear();
}
/// Set the default parser
pub fn set_default_parser(&self, name: impl Into<String>) {
let mut default = self.default_parser.write().unwrap();
*default = name.into();
}
}
impl Default for ToolParserRegistry {
fn default() -> Self {
Self::new()
}
}
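// Illustrative (hypothetical model names, not part of this file): resolution order for
// pooled lookups is exact mapping first, then the longest trailing-`*` prefix pattern
// (max_by_key on pattern length), then the default parser.
#[cfg(test)]
mod resolution_example {
    use super::ToolParserFactory;
    #[test]
    fn longest_prefix_pattern_wins() {
        let factory = ToolParserFactory::new();
        // "deepseek-v3-0324" matches both "deepseek-*" and "deepseek-v3*";
        // the longer pattern "deepseek-v3*" wins -> the "deepseek" parser.
        let _pooled = factory.get_pooled("deepseek-v3-0324");
        // "deepseek-r1" only matches "deepseek-*" -> the "pythonic" parser.
        let _other = factory.get_pooled("deepseek-r1");
        // No mapping matches -> falls back to the default "json" parser.
        let _fallback = factory.get_pooled("totally-unknown-model");
    }
}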
/// Factory for creating tool parsers based on model type.
#[derive(Clone)]
pub struct ToolParserFactory {
registry: ToolParserRegistry,
}
impl ToolParserFactory {
/// Create a new factory with default parsers registered.
pub fn new() -> Self {
let registry = ToolParserRegistry::new();
// Register default parsers
registry.register_parser("json", || Box::new(JsonParser::new()));
registry.register_parser("mistral", || Box::new(MistralParser::new()));
registry.register_parser("qwen", || Box::new(QwenParser::new()));
registry.register_parser("pythonic", || Box::new(PythonicParser::new()));
registry.register_parser("llama", || Box::new(LlamaParser::new()));
registry.register_parser("deepseek", || Box::new(DeepSeekParser::new()));
registry.register_parser("glm4_moe", || Box::new(Glm4MoeParser::new()));
registry.register_parser("step3", || Box::new(Step3Parser::new()));
registry.register_parser("kimik2", || Box::new(KimiK2Parser::new()));
// Register GPT-OSS parsers
registry.register_parser("gpt_oss_legacy", || Box::new(GptOssParser::new()));
registry.register_parser("gpt_oss_harmony", || Box::new(GptOssHarmonyParser::new()));
// Choose which GPT-OSS variant to use as default
if use_harmony_gpt_oss() {
registry.register_parser("gpt_oss", || Box::new(GptOssHarmonyParser::new()));
} else {
registry.register_parser("gpt_oss", || Box::new(GptOssParser::new()));
}
// Register default model mappings
Self::register_default_mappings(&registry);
Self { registry }
}
fn register_default_mappings(registry: &ToolParserRegistry) {
// OpenAI models
registry.map_model("gpt-4*", "json");
registry.map_model("gpt-3.5*", "json");
registry.map_model("gpt-4o*", "json");
// Anthropic models
registry.map_model("claude-*", "json");
// Mistral models
registry.map_model("mistral-*", "mistral");
registry.map_model("mixtral-*", "mistral");
// Qwen models
registry.map_model("qwen*", "qwen");
registry.map_model("Qwen*", "qwen");
// Llama models
registry.map_model("llama-4*", "pythonic");
registry.map_model("meta-llama-4*", "pythonic");
registry.map_model("llama-3.2*", "llama");
registry.map_model("meta-llama-3.2*", "llama");
registry.map_model("llama-*", "json");
registry.map_model("meta-llama-*", "json");
// DeepSeek models
registry.map_model("deepseek-v3*", "deepseek");
registry.map_model("deepseek-ai/DeepSeek-V3*", "deepseek");
registry.map_model("deepseek-*", "pythonic");
// GLM models
registry.map_model("glm-4.5*", "glm4_moe");
registry.map_model("glm-4.6*", "glm4_moe");
registry.map_model("glm-*", "json");
// Step3 models
registry.map_model("step3*", "step3");
registry.map_model("Step-3*", "step3");
// Kimi models
registry.map_model("kimi-k2*", "kimik2");
registry.map_model("Kimi-K2*", "kimik2");
registry.map_model("moonshot*/Kimi-K2*", "kimik2");
// GPT-OSS models
registry.map_model("gpt-oss*", "gpt_oss");
registry.map_model("t4-*", "gpt_oss");
// Other models
registry.map_model("gemini-*", "json");
registry.map_model("palm-*", "json");
registry.map_model("gemma-*", "json");
}
/// Get a pooled parser for the given model ID.
/// Returns a shared instance that can be used concurrently.
/// Falls back to JSON parser if model is not recognized.
pub fn get_pooled(&self, model_id: &str) -> PooledToolParser {
self.registry
.get_pooled_for_model(model_id)
.unwrap_or_else(|| {
// Fallback to JSON parser
self.registry
.get_pooled_parser("json")
.expect("JSON parser should always be registered")
})
}
/// Get the internal registry for custom registration.
pub fn registry(&self) -> &ToolParserRegistry {
&self.registry
}
/// Clear the parser pool.
pub fn clear_pool(&self) {
self.registry.clear_pool();
}
/// Get a non-pooled parser for the given model ID (creates a fresh instance each time).
/// This is useful for benchmarks and testing where you want independent parser instances.
pub fn get_parser(&self, model_id: &str) -> Option<Arc<dyn ToolParser>> {
// Determine which parser type to use
let parser_type = {
let mapping = self.registry.model_mapping.read().unwrap();
// Try exact match first
if let Some(parser_name) = mapping.get(model_id) {
parser_name.clone()
} else {
// Try prefix matching
let best_match = mapping
.iter()
.filter(|(pattern, _)| {
pattern.ends_with('*')
&& model_id.starts_with(&pattern[..pattern.len() - 1])
})
.max_by_key(|(pattern, _)| pattern.len());
if let Some((_, parser_name)) = best_match {
parser_name.clone()
} else {
// Fall back to default
self.registry.default_parser.read().unwrap().clone()
}
}
};
let creators = self.registry.creators.read().unwrap();
creators.get(&parser_type).map(|creator| {
// Call the creator to get a Box<dyn ToolParser>, then convert to Arc
let boxed_parser = creator();
Arc::from(boxed_parser)
})
}
/// List all registered parsers (for compatibility with old API).
pub fn list_parsers(&self) -> Vec<String> {
self.registry
.creators
.read()
.unwrap()
.keys()
.cloned()
.collect()
}
}
impl Default for ToolParserFactory {
fn default() -> Self {
Self::new()
}
}
fn use_harmony_gpt_oss() -> bool {
std::env::var("ROUTER_USE_HARMONY_GPT_OSS")
.ok()
.map(|value| {
let normalized = value.trim();
matches!(
normalized,
"1" | "true" | "TRUE" | "True" | "yes" | "YES" | "Yes" | "on" | "ON" | "On"
)
})
.unwrap_or(false)
}
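// Illustrative (hypothetical names, not part of this file): custom formats can be hooked
// in through `registry()` before serving traffic; the ROUTER_USE_HARMONY_GPT_OSS toggle
// above only selects which implementation backs the default "gpt_oss" name.
#[cfg(test)]
mod custom_registration_example {
    use super::ToolParserFactory;
    use crate::tool_parser::parsers::JsonParser;
    #[test]
    fn registers_and_resolves_custom_parser() {
        let factory = ToolParserFactory::new();
        factory.registry().register_parser("my_json", || Box::new(JsonParser::new()));
        factory.registry().map_model("my-model-*", "my_json");
        let _pooled = factory.get_pooled("my-model-7b"); // resolves via the new mapping
    }
}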
@@ -3,8 +3,8 @@
 /// This module provides infrastructure for parsing tool calls from various model formats.
 // Core modules
 pub mod errors;
+pub mod factory;
 pub mod partial_json;
-pub mod registry;
 pub mod state;
 pub mod traits;
 pub mod types;
@@ -17,10 +17,9 @@ mod tests;
 // Re-export commonly used types
 pub use errors::{ToolParserError, ToolParserResult};
-pub use registry::ParserRegistry;
-pub use state::{ParsePhase, ParseState};
+pub use factory::{PooledToolParser, ToolParserFactory, ToolParserRegistry};
 pub use traits::{PartialJsonParser, ToolParser};
-pub use types::{FunctionCall, PartialToolCall, StreamResult, ToolCall};
+pub use types::{FunctionCall, PartialToolCall, StreamingParseResult, ToolCall};
 // Re-export parsers for convenience
 pub use parsers::{
...
@@ -2,12 +2,13 @@ use async_trait::async_trait;
 use regex::Regex;
 use serde_json::Value;
+use crate::protocols::spec::Tool;
 use crate::tool_parser::{
     errors::{ToolParserError, ToolParserResult},
-    partial_json::PartialJson,
-    state::ParseState,
+    parsers::helpers,
     traits::ToolParser,
-    types::{FunctionCall, StreamResult, ToolCall},
+    types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
 };
 /// DeepSeek V3 format parser for tool calls
@@ -20,12 +21,29 @@ use crate::tool_parser::{
 /// - JSON arguments in code blocks
 /// - Support for multiple sequential tool calls
 pub struct DeepSeekParser {
-    /// Parser for handling incomplete JSON during streaming
-    partial_json: PartialJson,
     /// Regex for extracting complete tool calls
     tool_call_extractor: Regex,
     /// Regex for extracting function details
     func_detail_extractor: Regex,
+    /// Regex for matching partial tool calls during streaming
+    partial_tool_call_regex: Regex,
+    /// Regex pattern for removing completed tool calls from buffer
+    tool_call_end_pattern: Regex,
+    /// Buffer for accumulating incomplete patterns across chunks
+    buffer: String,
+    /// Stores complete tool call info (name and arguments) for each tool being parsed
+    prev_tool_call_arr: Vec<Value>,
+    /// Index of currently streaming tool call (-1 means no active tool)
+    current_tool_id: i32,
+    /// Flag for whether current tool's name has been sent to client
+    current_tool_name_sent: bool,
+    /// Tracks raw JSON string content streamed to client for each tool's arguments
+    streamed_args_for_tool: Vec<String>,
 }
 impl DeepSeekParser {
@@ -38,10 +56,24 @@ impl DeepSeekParser {
         let func_detail_pattern = r"(?s)<|tool▁call▁begin|>(.*?)<|tool▁sep|>(.*?)\n```json\n(.*?)\n```<|tool▁call▁end|>";
         let func_detail_extractor = Regex::new(func_detail_pattern).expect("Valid regex pattern");
+        // Partial pattern for streaming - uses .* (greedy) not .*? to match all partial content
+        let partial_pattern = r"(?s)<|tool▁call▁begin|>(.*)<|tool▁sep|>(.*)\n```json\n(.*)";
+        let partial_tool_call_regex = Regex::new(partial_pattern).expect("Valid regex pattern");
+        // Pattern for removing completed tool calls
+        let end_pattern = r"(?s)<|tool▁call▁begin|>.*?<|tool▁call▁end|>";
+        let tool_call_end_pattern = Regex::new(end_pattern).expect("Valid regex pattern");
         Self {
-            partial_json: PartialJson::default(),
             tool_call_extractor,
             func_detail_extractor,
+            partial_tool_call_regex,
+            tool_call_end_pattern,
+            buffer: String::new(),
+            prev_tool_call_arr: Vec::new(),
+            current_tool_id: -1,
+            current_tool_name_sent: false,
+            streamed_args_for_tool: Vec::new(),
         }
     }
@@ -143,107 +175,146 @@ impl ToolParser for DeepSeekParser {
     }
     async fn parse_incremental(
-        &self,
+        &mut self,
         chunk: &str,
-        state: &mut ParseState,
-    ) -> ToolParserResult<StreamResult> {
-        state.buffer.push_str(chunk);
-
-        // Check for tool markers
-        if !self.has_tool_markers(&state.buffer) {
-            // No tool markers detected - return all buffered content as normal text
-            let normal_text = std::mem::take(&mut state.buffer);
-            return Ok(StreamResult::NormalText(normal_text));
-        }
-
-        // Check for text before tool markers and extract it as normal text
-        if let Some(marker_pos) = state.buffer.find("<|tool▁calls▁begin|>") {
-            if marker_pos > 0 {
-                // We have text before the tool marker - extract it as normal text
-                let normal_text: String = state.buffer.drain(..marker_pos).collect();
-                return Ok(StreamResult::NormalText(normal_text));
-            }
-        }
-
-        // Look for start of tool calls
-        if let Some(start_pos) = state.buffer.find("<|tool▁calls▁begin|>") {
-            // Look for individual tool call start
-            let search_from = start_pos + "<|tool▁calls▁begin|>".len();
-            if let Some(call_start) = state.buffer[search_from..].find("<|tool▁call▁begin|>")
-            {
-                let call_start_abs = search_from + call_start;
-
-                // Look for the end of this tool call
-                let search_end_from = call_start_abs + "<|tool▁call▁begin|>".len();
-                if let Some(call_end) = state.buffer[search_end_from..].find("<|tool▁call▁end|>")
-                {
-                    let call_end_abs = search_end_from + call_end + "<|tool▁call▁end|>".len();
-
-                    // Extract and parse the complete tool call
-                    let tool_call_text = &state.buffer[call_start_abs..call_end_abs];
-
-                    match self.parse_tool_call(tool_call_text) {
-                        Ok(tool) => {
-                            // Remove the processed part from buffer
-                            state.buffer.drain(..call_end_abs);
-                            return Ok(StreamResult::ToolComplete(tool));
-                        }
-                        Err(_) => {
-                            // Parsing failed, skip this tool call
-                            state.buffer.drain(..call_end_abs);
-                        }
-                    }
-                } else {
-                    // Tool call not complete yet, try to extract partial info
-                    let partial = &state.buffer[search_end_from..];
-
-                    // Try to extract function name
-                    if let Some(sep_pos) = partial.find("<|tool▁sep|>") {
-                        if let Some(_func_start) = partial[..sep_pos].rfind("function") {
-                            // We have the function type marker
-                            let after_sep = &partial[sep_pos + "<|tool▁sep|>".len()..];
-
-                            // Look for function name (ends at newline before ```json)
-                            if let Some(name_end) = after_sep.find("\n```json\n") {
-                                let func_name = after_sep[..name_end].trim();
-                                if !state.in_string {
-                                    state.in_string = true; // Mark name as sent
-                                    return Ok(StreamResult::ToolName {
-                                        index: 0,
-                                        name: func_name.to_string(),
-                                    });
-                                }
-
-                                // Try to extract partial arguments
-                                let args_start = name_end + "\n```json\n".len();
-                                let partial_args = &after_sep[args_start..];
-
-                                // Check if we can parse partial JSON
-                                if !partial_args.is_empty() {
-                                    match self.partial_json.parse_value(partial_args) {
-                                        Ok((value, _consumed)) => {
-                                            let args_str = serde_json::to_string(&value)
-                                                .unwrap_or_else(|_| "{}".to_string());
-                                            return Ok(StreamResult::ToolArguments {
-                                                index: 0,
-                                                arguments: args_str,
-                                            });
-                                        }
-                                        Err(_) => {
-                                            // Can't parse yet, continue waiting for more data
-                                        }
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-            }
-        }
-        Ok(StreamResult::Incomplete)
+        tools: &[Tool],
+    ) -> ToolParserResult<StreamingParseResult> {
+        self.buffer.push_str(chunk);
+        let current_text = &self.buffer.clone();
+
+        // Check if we have a tool call (either the start token or individual tool call)
+        let has_tool_call =
+            self.has_tool_markers(current_text) || current_text.contains("<|tool▁call▁begin|>");
+
+        if !has_tool_call {
+            // No tool markers detected - return all buffered content as normal text
+            // Strip out end tokens if present
+            let mut normal_text = std::mem::take(&mut self.buffer);
+            for e_token in ["<|tool▁calls▁end|>", "```", "<|tool▁call▁end|>"] {
+                normal_text = normal_text.replace(e_token, "");
+            }
+            return Ok(StreamingParseResult {
+                normal_text,
+                calls: vec![],
+            });
+        }
+
+        // Build tool indices for validation
+        let tool_indices = helpers::get_tool_indices(tools);
+
+        let mut calls: Vec<ToolCallItem> = Vec::new();
+
+        // Try to match the partial tool call pattern
+        if let Some(captures) = self.partial_tool_call_regex.captures(current_text) {
+            let func_name = captures.get(2).map_or("", |m| m.as_str()).trim();
+            let func_args_raw = captures.get(3).map_or("", |m| m.as_str()).trim();
+
+            // Validate tool name
+            if !tool_indices.contains_key(func_name) {
+                // Invalid tool name - skip this tool, preserve indexing for next tool
+                tracing::warn!("Invalid tool name '{}' - skipping", func_name);
+                helpers::reset_current_tool_state(
+                    &mut self.buffer,
+                    &mut self.current_tool_name_sent,
+                    &mut self.streamed_args_for_tool,
+                    &self.prev_tool_call_arr,
+                );
+                return Ok(StreamingParseResult::default());
+            }
+
+            // Initialize state if this is the first tool call
+            if self.current_tool_id == -1 {
+                self.current_tool_id = 0;
+                self.prev_tool_call_arr = Vec::new();
+                self.streamed_args_for_tool = vec![String::new()];
+            }
+
+            // Ensure we have enough entries in our tracking arrays
+            helpers::ensure_capacity(
+                self.current_tool_id,
+                &mut self.prev_tool_call_arr,
+                &mut self.streamed_args_for_tool,
+            );
+
+            // Send tool name if not sent yet
+            if !self.current_tool_name_sent {
+                calls.push(ToolCallItem {
+                    tool_index: self.current_tool_id as usize,
+                    name: Some(func_name.to_string()),
+                    parameters: String::new(),
+                });
+                self.current_tool_name_sent = true;
+
+                // Store the tool call info for serving layer completions endpoint
+                let tool_id = self.current_tool_id as usize;
+                if self.prev_tool_call_arr.len() <= tool_id {
+                    self.prev_tool_call_arr
+                        .resize_with(tool_id + 1, || Value::Null);
+                }
+                self.prev_tool_call_arr[tool_id] = serde_json::json!({
+                    "name": func_name,
+                    "arguments": {},
+                });
+            } else {
+                // Compute incremental diff
+                let tool_id = self.current_tool_id as usize;
+                let last_sent = self
+                    .streamed_args_for_tool
+                    .get(tool_id)
+                    .map(|s| s.as_str())
+                    .unwrap_or("");
+                let argument_diff = func_args_raw
+                    .strip_prefix(last_sent)
+                    .unwrap_or(func_args_raw);
+                if !argument_diff.is_empty() {
+                    calls.push(ToolCallItem {
+                        tool_index: tool_id,
+                        name: None,
+                        parameters: argument_diff.to_string(),
+                    });
+                    if tool_id < self.streamed_args_for_tool.len() {
+                        self.streamed_args_for_tool[tool_id].push_str(argument_diff);
+                    }
+                }
+
+                // Check if JSON is complete
+                if helpers::is_complete_json(func_args_raw) {
+                    // Update the stored arguments
+                    if let Ok(parsed_args) = serde_json::from_str::<Value>(func_args_raw) {
+                        let tool_id = self.current_tool_id as usize;
+                        if tool_id < self.prev_tool_call_arr.len() {
+                            if let Some(obj) = self.prev_tool_call_arr[tool_id].as_object_mut() {
+                                obj.insert("arguments".to_string(), parsed_args);
+                            }
+                        }
+                    }
+
+                    // Find the end of the current tool call and remove only that part from buffer
+                    if let Some(mat) = self.tool_call_end_pattern.find(current_text) {
+                        // Remove the completed tool call from buffer, keep any remaining content
+                        self.buffer = current_text[mat.end()..].to_string();
+                    } else {
+                        self.buffer.clear();
+                    }
+
+                    let result = StreamingParseResult {
+                        normal_text: String::new(),
+                        calls,
+                    };
+                    self.current_tool_id += 1;
+                    self.current_tool_name_sent = false;
+                    return Ok(result);
+                }
+            }
+        }
+
+        Ok(StreamingParseResult {
+            normal_text: String::new(),
+            calls,
+        })
     }
     fn detect_format(&self, text: &str) -> bool {
...
@@ -2,11 +2,13 @@ use async_trait::async_trait;
 use regex::Regex;
 use serde_json::Value;
+use crate::protocols::spec::Tool;
 use crate::tool_parser::{
     errors::{ToolParserError, ToolParserResult},
-    state::ParseState,
+    parsers::helpers,
     traits::ToolParser,
-    types::{FunctionCall, StreamResult, ToolCall},
+    types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
 };
 /// GLM-4 MoE format parser for tool calls
@@ -25,6 +27,22 @@ pub struct Glm4MoeParser {
     func_detail_extractor: Regex,
     /// Regex for extracting argument key-value pairs
     arg_extractor: Regex,
+    /// Buffer for accumulating incomplete patterns across chunks
+    buffer: String,
+    /// Stores complete tool call info (name and arguments) for each tool being parsed
+    prev_tool_call_arr: Vec<Value>,
+    /// Index of currently streaming tool call (-1 means no active tool)
+    current_tool_id: i32,
+    /// Tracks raw JSON string content streamed to client for each tool's arguments
+    streamed_args_for_tool: Vec<String>,
+    /// Token configuration
+    bot_token: &'static str,
+    eot_token: &'static str,
 }
 impl Glm4MoeParser {
@@ -44,12 +62,18 @@ impl Glm4MoeParser {
             tool_call_extractor,
             func_detail_extractor,
             arg_extractor,
+            buffer: String::new(),
+            prev_tool_call_arr: Vec::new(),
+            current_tool_id: -1,
+            streamed_args_for_tool: Vec::new(),
+            bot_token: "<tool_call>",
+            eot_token: "</tool_call>",
         }
     }
     /// Check if text contains GLM-4 MoE tool markers
     fn has_tool_markers(&self, text: &str) -> bool {
-        text.contains("<tool_call>")
+        text.contains(self.bot_token)
     }
     /// Parse arguments from key-value pairs
@@ -120,6 +144,25 @@ impl Glm4MoeParser {
             Ok(None)
         }
     }
+    /// Parse and return StreamingParseResult (mirrors Python's detect_and_parse)
+    /// Parse all tool calls from text (shared logic for complete and incremental parsing)
+    fn parse_tool_calls_from_text(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> {
+        let mut tools = Vec::new();
+        for mat in self.tool_call_extractor.find_iter(text) {
+            match self.parse_tool_call(mat.as_str()) {
+                Ok(Some(tool)) => tools.push(tool),
+                Ok(None) => continue,
+                Err(e) => {
+                    tracing::warn!("Failed to parse tool call: {}", e);
+                    continue;
+                }
+            }
+        }
+        Ok(tools)
+    }
 }
 impl Default for Glm4MoeParser {
@@ -140,18 +183,8 @@ impl ToolParser for Glm4MoeParser {
         let idx = text.find("<tool_call>").unwrap();
         let normal_text = text[..idx].to_string();
-        // Extract tool calls
-        let mut tools = Vec::new();
-        for mat in self.tool_call_extractor.find_iter(text) {
-            match self.parse_tool_call(mat.as_str()) {
-                Ok(Some(tool)) => tools.push(tool),
-                Ok(None) => continue,
-                Err(e) => {
-                    tracing::warn!("Failed to parse tool call: {}", e);
-                    continue;
-                }
-            }
-        }
+        // Parse all tool calls using shared helper
+        let tools = self.parse_tool_calls_from_text(text)?;
         // If no tools were successfully parsed despite having markers, return entire text as fallback
         if tools.is_empty() {
@@ -162,78 +195,127 @@ impl ToolParser for Glm4MoeParser {
     }
     async fn parse_incremental(
-        &self,
+        &mut self,
         chunk: &str,
-        state: &mut ParseState,
-    ) -> ToolParserResult<StreamResult> {
-        state.buffer.push_str(chunk);
-
-        // Check for tool markers
-        if !self.has_tool_markers(&state.buffer) {
-            // No tool markers detected - return all buffered content as normal text
-            let normal_text = std::mem::take(&mut state.buffer);
-            return Ok(StreamResult::NormalText(normal_text));
-        }
-
-        // Check for text before tool markers and extract it as normal text
-        if let Some(marker_pos) = state.buffer.find("<tool_call>") {
-            if marker_pos > 0 {
-                // We have text before the tool marker - extract it as normal text
-                let normal_text: String = state.buffer.drain(..marker_pos).collect();
-                return Ok(StreamResult::NormalText(normal_text));
-            }
-        }
-
-        // Look for start of tool call
-        if let Some(start_pos) = state.buffer.find("<tool_call>") {
-            // Look for the end of this tool call
-            let search_from = start_pos + "<tool_call>".len();
-            if let Some(end_pos) = state.buffer[search_from..].find("</tool_call>") {
-                let end_abs = search_from + end_pos + "</tool_call>".len();
-
-                // Extract and parse the complete tool call
-                let tool_call_text = &state.buffer[start_pos..end_abs];
-
-                if let Some(tool) = self.parse_tool_call(tool_call_text)? {
-                    // Remove the processed part from buffer
-                    state.buffer.drain(..end_abs);
-
-                    return Ok(StreamResult::ToolComplete(tool));
-                }
-            } else {
-                // Tool call not complete yet, try to extract partial info
-                let partial = &state.buffer[search_from..];
-
-                // Try to extract function name (first line after <tool_call>)
-                if let Some(name_end) = partial.find('\n') {
-                    let func_name = partial[..name_end].trim();
-                    if !func_name.is_empty() && !state.in_string {
-                        state.in_string = true; // Mark name as sent
-                        return Ok(StreamResult::ToolName {
-                            index: 0,
-                            name: func_name.to_string(),
-                        });
-                    }
-
-                    // Try to extract partial arguments
-                    let args_text = &partial[name_end + 1..];
-                    let partial_args = self.parse_arguments(args_text)?;
-
-                    if !partial_args.is_empty() {
-                        let args_str = serde_json::to_string(&partial_args)
-                            .unwrap_or_else(|_| "{}".to_string());
-
-                        return Ok(StreamResult::ToolArguments {
-                            index: 0,
-                            arguments: args_str,
-                        });
-                    }
-                }
-            }
-        }
-
-        Ok(StreamResult::Incomplete)
+        tools: &[Tool],
+    ) -> ToolParserResult<StreamingParseResult> {
+        // Python logic: Wait for complete tool call, then parse it all at once
+        self.buffer.push_str(chunk);
+        let current_text = &self.buffer.clone();
+
+        // Check if we have bot_token
+        let start = current_text.find(self.bot_token);
+        if start.is_none() {
+            self.buffer.clear();
+            // If we're in the middle of streaming (current_tool_id > 0), don't return text
+            let normal_text = if self.current_tool_id > 0 {
+                String::new()
+            } else {
+                current_text.clone()
+            };
+            return Ok(StreamingParseResult {
+                normal_text,
+                calls: vec![],
+            });
+        }
+
+        // Check if we have eot_token (end of tool call)
+        let end = current_text.find(self.eot_token);
+        if let Some(end_pos) = end {
+            // We have a complete tool call!
+
+            // Initialize state if this is the first tool call
+            if self.current_tool_id == -1 {
+                self.current_tool_id = 0;
+                self.prev_tool_call_arr = Vec::new();
+                self.streamed_args_for_tool = vec![String::new()];
+            }
+
+            // Ensure we have enough entries in our tracking arrays
+            helpers::ensure_capacity(
+                self.current_tool_id,
+                &mut self.prev_tool_call_arr,
+                &mut self.streamed_args_for_tool,
+            );
+
+            // Parse the complete block using shared helper
+            let block_end = end_pos + self.eot_token.len();
+            let parsed_tools = self.parse_tool_calls_from_text(&current_text[..block_end])?;
+
+            // Extract normal text before tool calls
+            let idx = current_text.find(self.bot_token);
+            let normal_text = if let Some(pos) = idx {
+                current_text[..pos].trim().to_string()
+            } else {
+                String::new()
+            };
+
+            // Build tool indices for validation
+            let tool_indices = helpers::get_tool_indices(tools);
+            let mut calls = Vec::new();
+            if !parsed_tools.is_empty() {
+                // Take the first tool and convert to ToolCallItem
+                let tool_call = &parsed_tools[0];
+                let tool_id = self.current_tool_id as usize;
+
+                // Validate tool name
+                if !tool_indices.contains_key(&tool_call.function.name) {
+                    // Invalid tool name - skip this tool, preserve indexing for next tool
+                    tracing::warn!("Invalid tool name '{}' - skipping", tool_call.function.name);
+                    helpers::reset_current_tool_state(
+                        &mut self.buffer,
+                        &mut false, // glm4_moe doesn't track name_sent per tool
+                        &mut self.streamed_args_for_tool,
+                        &self.prev_tool_call_arr,
+                    );
+                    return Ok(StreamingParseResult::default());
+                }
+
+                calls.push(ToolCallItem {
+                    tool_index: tool_id,
+                    name: Some(tool_call.function.name.clone()),
+                    parameters: tool_call.function.arguments.clone(),
+                });
+
+                // Store in tracking arrays
+                if self.prev_tool_call_arr.len() <= tool_id {
+                    self.prev_tool_call_arr
+                        .resize_with(tool_id + 1, || Value::Null);
+                }
+
+                // Parse parameters as JSON and store
+                if let Ok(args) = serde_json::from_str::<Value>(&tool_call.function.arguments) {
+                    self.prev_tool_call_arr[tool_id] = serde_json::json!({
+                        "name": tool_call.function.name,
+                        "arguments": args,
+                    });
+                }
+
+                if self.streamed_args_for_tool.len() <= tool_id {
+                    self.streamed_args_for_tool
+                        .resize_with(tool_id + 1, String::new);
+                }
+                self.streamed_args_for_tool[tool_id] = tool_call.function.arguments.clone();
+                self.current_tool_id += 1;
+            }
+
+            // Remove processed portion from buffer
+            self.buffer = current_text[block_end..].to_string();
+            return Ok(StreamingParseResult { normal_text, calls });
+        }
+
+        // No complete tool call yet - return normal text before start token
+        let start_pos = start.unwrap();
+        let normal_text = current_text[..start_pos].to_string();
+        self.buffer = current_text[start_pos..].to_string();
+        Ok(StreamingParseResult {
+            normal_text,
+            calls: vec![],
+        })
     }
     fn detect_format(&self, text: &str) -> bool {
...
 use async_trait::async_trait;
+use crate::protocols::spec::Tool;
 use crate::tool_parser::{
     errors::ToolParserResult,
-    state::ParseState,
     traits::{TokenToolParser, ToolParser},
-    types::{StreamResult, ToolCall},
+    types::{StreamingParseResult, ToolCall},
 };
 /// Placeholder for the Harmony-backed GPT-OSS parser.
@@ -29,12 +30,12 @@ impl ToolParser for GptOssHarmonyParser {
     }
     async fn parse_incremental(
-        &self,
+        &mut self,
         _chunk: &str,
-        _state: &mut ParseState,
-    ) -> ToolParserResult<StreamResult> {
+        _tools: &[Tool],
+    ) -> ToolParserResult<StreamingParseResult> {
         // Temporary stub until the Harmony streaming pipeline is implemented.
-        Ok(StreamResult::Incomplete)
+        Ok(StreamingParseResult::default())
     }
     fn detect_format(&self, text: &str) -> bool {
@@ -61,10 +62,10 @@ impl TokenToolParser for GptOssHarmonyParser {
     }
     async fn parse_incremental_tokens(
-        &self,
+        &mut self,
         _tokens: &[u32],
-        _state: &mut ParseState,
-    ) -> ToolParserResult<StreamResult> {
-        Ok(StreamResult::Incomplete)
+        _tools: &[Tool],
+    ) -> ToolParserResult<StreamingParseResult> {
+        Ok(StreamingParseResult::default())
     }
 }
@@ -2,12 +2,14 @@ use async_trait::async_trait;
 use regex::Regex;
 use serde_json::Value;
+use crate::protocols::spec::Tool;
 use crate::tool_parser::{
     errors::{ToolParserError, ToolParserResult},
+    parsers::helpers,
     partial_json::PartialJson,
-    state::ParseState,
     traits::ToolParser,
-    types::{FunctionCall, StreamResult, ToolCall},
+    types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
 };
 /// GPT-OSS format parser for tool calls
@@ -26,6 +28,11 @@ pub struct GptOssParser {
     function_call_extractor: Regex,
     /// Regex for extracting streaming function calls
     streaming_extractor: Regex,
+    /// Buffer for accumulating chunks
+    buffer: String,
+    /// Whether the tool name has been sent (for streaming)
+    name_sent: bool,
 }
 impl GptOssParser {
@@ -45,6 +52,9 @@ impl GptOssParser {
             partial_json: PartialJson::default(),
             function_call_extractor,
             streaming_extractor,
+            buffer: String::new(),
+            name_sent: false,
         }
     }
@@ -123,21 +133,21 @@ impl ToolParser for GptOssParser {
     }
     async fn parse_incremental(
-        &self,
+        &mut self,
         chunk: &str,
-        state: &mut ParseState,
-    ) -> ToolParserResult<StreamResult> {
-        state.buffer.push_str(chunk);
+        tools: &[Tool],
+    ) -> ToolParserResult<StreamingParseResult> {
+        self.buffer.push_str(chunk);
         // Check for tool markers
-        if !self.has_tool_markers(&state.buffer) {
+        if !self.has_tool_markers(&self.buffer) {
             // No markers found, clear buffer and return
-            state.buffer.clear();
-            return Ok(StreamResult::Incomplete);
+            self.buffer.clear();
+            return Ok(StreamingParseResult::default());
         }
         // Try to match streaming pattern
-        if let Some(captures) = self.streaming_extractor.captures(&state.buffer) {
+        if let Some(captures) = self.streaming_extractor.captures(&self.buffer) {
             if let (Some(name_match), Some(args_match)) = (captures.get(1), captures.get(2)) {
                 let full_function_name = name_match.as_str();
                 let partial_args = args_match.as_str();
@@ -146,16 +156,30 @@ impl ToolParser for GptOssParser {
                 let function_name = self.extract_function_name(full_function_name);
                 // Send function name if not sent yet
-                if !state.in_string {
-                    state.in_string = true; // Mark name as sent
-                    return Ok(StreamResult::ToolName {
-                        index: 0,
-                        name: function_name.clone(),
+                if !self.name_sent {
+                    // Validate tool name
+                    let tool_indices = helpers::get_tool_indices(tools);
+                    if !tool_indices.contains_key(&function_name) {
+                        // Invalid tool name - skip
+                        tracing::warn!("Invalid tool name '{}' - skipping", function_name);
+                        self.buffer.clear();
+                        self.name_sent = false;
+                        return Ok(StreamingParseResult::default());
+                    }
+                    self.name_sent = true; // Mark name as sent
+                    return Ok(StreamingParseResult {
+                        normal_text: String::new(),
+                        calls: vec![ToolCallItem {
+                            tool_index: 0,
+                            name: Some(function_name.clone()),
+                            parameters: String::new(),
+                        }],
                     });
                 }
                 // Check if we have a complete function call
-                if let Some(complete_match) = self.function_call_extractor.captures(&state.buffer) {
+                if let Some(complete_match) = self.function_call_extractor.captures(&self.buffer) {
                     if let Some(args_match) = complete_match.get(2) {
                         let args_content = args_match.as_str().trim();
@@ -170,26 +194,22 @@ impl ToolParser for GptOssParser {
                             }
                         };
-                        // Generate unique ID
-                        let id = format!("gpt_oss_call_{}", uuid::Uuid::new_v4());
-                        let tool = ToolCall {
-                            id,
-                            r#type: "function".to_string(),
-                            function: FunctionCall {
-                                name: function_name,
-                                arguments,
-                            },
-                        };
                         // Remove the processed part from buffer
                         let complete_end = complete_match.get(0).unwrap().end();
-                        state.buffer.drain(..complete_end);
+                        self.buffer.drain(..complete_end);
                         // Reset state for next tool
-                        state.in_string = false;
-                        return Ok(StreamResult::ToolComplete(tool));
+                        self.name_sent = false;
+                        // Return final arguments
+                        return Ok(StreamingParseResult {
+                            normal_text: String::new(),
+                            calls: vec![ToolCallItem {
+                                tool_index: 0,
+                                name: None,
+                                parameters: arguments,
+                            }],
+                        });
                     }
                 } else {
                     // Try to parse partial JSON for streaming arguments
@@ -206,9 +226,13 @@ impl ToolParser for GptOssParser {
                             let args_str = serde_json::to_string(&value)
                                 .unwrap_or_else(|_| "{}".to_string());
-                            return Ok(StreamResult::ToolArguments {
-                                index: 0,
-                                arguments: args_str,
+                            return Ok(StreamingParseResult {
+                                normal_text: String::new(),
+                                calls: vec![ToolCallItem {
+                                    tool_index: 0,
+                                    name: None,
+                                    parameters: args_str,
+                                }],
                             });
                         }
                         Err(_) => {
@@ -220,7 +244,7 @@ impl ToolParser for GptOssParser {
             }
         }
-        Ok(StreamResult::Incomplete)
+        Ok(StreamingParseResult::default())
     }
     fn detect_format(&self, text: &str) -> bool {
...
use crate::protocols::spec::Tool;
use serde_json::Value;
use std::collections::HashMap;
use crate::tool_parser::errors::{ToolParserError, ToolParserResult};
use crate::tool_parser::types::{StreamingParseResult, ToolCallItem};
/// Get a mapping of tool names to their indices
pub fn get_tool_indices(tools: &[Tool]) -> HashMap<String, usize> {
tools
.iter()
.enumerate()
.map(|(i, tool)| (tool.function.name.clone(), i))
.collect()
}
/// Check if a buffer ends with a partial occurrence of a token.
/// Returns Some(len) for the shortest token prefix the buffer ends with
/// (i.e., the stream may be mid-token), or None if there is no partial match.
pub fn ends_with_partial_token(buffer: &str, token: &str) -> Option<usize> {
if buffer.is_empty() || token.is_empty() {
return None;
}
(1..token.len()).find(|&i| buffer.ends_with(&token[..i]))
}
/// Reset state for the current tool being parsed (used when skipping invalid tools).
/// This preserves the parser's overall state (current_tool_id, prev_tool_call_arr)
/// but clears the state specific to the current incomplete tool.
pub fn reset_current_tool_state(
buffer: &mut String,
current_tool_name_sent: &mut bool,
streamed_args_for_tool: &mut Vec<String>,
prev_tool_call_arr: &[Value],
) {
buffer.clear();
*current_tool_name_sent = false;
// Only pop if we added an entry for the current (invalid) tool
// streamed_args_for_tool should match prev_tool_call_arr length for completed tools
if streamed_args_for_tool.len() > prev_tool_call_arr.len() {
streamed_args_for_tool.pop();
}
}
/// Reset the entire parser state (used at the start of a new request).
/// Clears all accumulated tool calls and resets all state to initial values.
pub fn reset_parser_state(
buffer: &mut String,
prev_tool_call_arr: &mut Vec<Value>,
current_tool_id: &mut i32,
current_tool_name_sent: &mut bool,
streamed_args_for_tool: &mut Vec<String>,
) {
buffer.clear();
prev_tool_call_arr.clear();
*current_tool_id = 0;
*current_tool_name_sent = false;
streamed_args_for_tool.clear();
}
/// Ensure arrays have capacity for the given tool ID
pub fn ensure_capacity(
current_tool_id: i32,
prev_tool_call_arr: &mut Vec<Value>,
streamed_args_for_tool: &mut Vec<String>,
) {
if current_tool_id < 0 {
return;
}
let needed = (current_tool_id + 1) as usize;
if prev_tool_call_arr.len() < needed {
prev_tool_call_arr.resize_with(needed, || Value::Null);
}
if streamed_args_for_tool.len() < needed {
streamed_args_for_tool.resize_with(needed, String::new);
}
}
/// Check if a string contains complete, valid JSON
pub fn is_complete_json(input: &str) -> bool {
serde_json::from_str::<Value>(input).is_ok()
}
/// Normalize the arguments/parameters field in a tool call object.
/// If the object has "parameters" but not "arguments", copy parameters to arguments.
///
/// # Background
/// Different LLM formats use different field names:
/// - Llama and JSON parsers use "parameters" (correct per JSON Schema spec)
/// - Mistral and Qwen use "arguments"
///
/// This function normalizes to "arguments" for consistent downstream processing.
pub fn normalize_arguments_field(mut obj: Value) -> Value {
if obj.get("arguments").is_none() {
if let Some(params) = obj.get("parameters").cloned() {
if let Value::Object(ref mut map) = obj {
map.insert("arguments".to_string(), params);
}
}
}
obj
}
/// Handle the entire JSON tool call streaming process for JSON-based parsers.
///
/// This unified function handles all aspects of streaming tool calls:
/// - Parsing partial JSON from the buffer
/// - Validating tool names against available tools
/// - Streaming tool names (Case 1)
/// - Streaming tool arguments (Case 2)
/// - Managing parser state and buffer updates
///
/// Used by JSON, Llama, Mistral, and Qwen parsers.
///
/// # Parameters
/// - `current_text`: The current buffered text being parsed
/// - `start_idx`: Start index of JSON content in current_text
/// - `partial_json`: Mutable reference to partial JSON parser
/// - `tool_indices`: Map of valid tool names to their indices
/// - `buffer`: Mutable parser buffer
/// - `current_tool_id`: Mutable current tool index (-1 means no active tool)
/// - `current_tool_name_sent`: Mutable flag for whether current tool's name was sent
/// - `streamed_args_for_tool`: Mutable accumulator of streamed arguments per tool
/// - `prev_tool_call_arr`: Mutable array of previous tool call states
///
/// # Returns
/// - `Ok(StreamingParseResult)` with any tool call items to stream
/// - `Err(ToolParserError)` if JSON parsing or serialization fails
#[allow(clippy::too_many_arguments)]
pub fn handle_json_tool_streaming(
current_text: &str,
start_idx: usize,
partial_json: &mut crate::tool_parser::partial_json::PartialJson,
tool_indices: &HashMap<String, usize>,
buffer: &mut String,
current_tool_id: &mut i32,
current_tool_name_sent: &mut bool,
streamed_args_for_tool: &mut Vec<String>,
prev_tool_call_arr: &mut Vec<Value>,
) -> ToolParserResult<StreamingParseResult> {
// Check if we have content to parse
if start_idx >= current_text.len() {
return Ok(StreamingParseResult::default());
}
// Extract JSON string from current position
let json_str = &current_text[start_idx..];
// Parse partial JSON
let (obj, end_idx) = match partial_json.parse_value(json_str) {
Ok(result) => result,
Err(_) => {
return Ok(StreamingParseResult::default());
}
};
// Check if JSON is complete
let is_complete = end_idx == json_str.len() && serde_json::from_str::<Value>(json_str).is_ok();
// Validate tool name if present
if let Some(name) = obj.get("name").and_then(|v| v.as_str()) {
if !tool_indices.contains_key(name) {
// Invalid tool name - skip this tool, preserve indexing for next tool
tracing::warn!("Invalid tool name '{}' - skipping", name);
reset_current_tool_state(
buffer,
current_tool_name_sent,
streamed_args_for_tool,
prev_tool_call_arr,
);
return Ok(StreamingParseResult::default());
}
}
// Normalize parameters/arguments field
let current_tool_call = normalize_arguments_field(obj);
let mut result = StreamingParseResult::default();
// Case 1: Handle tool name streaming
if !*current_tool_name_sent {
if let Some(function_name) = current_tool_call.get("name").and_then(|v| v.as_str()) {
if tool_indices.contains_key(function_name) {
// Initialize if first tool
if *current_tool_id == -1 {
*current_tool_id = 0;
streamed_args_for_tool.push(String::new());
} else if *current_tool_id as usize >= streamed_args_for_tool.len() {
// Ensure capacity for subsequent tools
ensure_capacity(*current_tool_id, prev_tool_call_arr, streamed_args_for_tool);
}
// Send tool name with empty parameters
*current_tool_name_sent = true;
result.calls.push(ToolCallItem {
tool_index: *current_tool_id as usize,
name: Some(function_name.to_string()),
parameters: String::new(),
});
}
}
}
// Case 2: Handle streaming arguments
else if let Some(cur_arguments) = current_tool_call.get("arguments") {
let tool_id = *current_tool_id as usize;
let sent = streamed_args_for_tool
.get(tool_id)
.map(|s| s.len())
.unwrap_or(0);
let cur_args_json = serde_json::to_string(cur_arguments)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
// Compute diff: everything after what we've already sent
let diff = cur_args_json[sent..].to_string();
// Send diff if there's new content
if !diff.is_empty() {
// Only accumulate if not complete
if !is_complete && tool_id < streamed_args_for_tool.len() {
streamed_args_for_tool[tool_id].push_str(&diff);
}
result.calls.push(ToolCallItem {
tool_index: tool_id,
name: None,
parameters: diff,
});
}
// If JSON is complete, advance to next tool
if is_complete {
// Remove processed portion, keep unprocessed content
*buffer = current_text[start_idx + end_idx..].to_string();
// Clear completed tool data
if tool_id < prev_tool_call_arr.len() {
prev_tool_call_arr[tool_id] = Value::Null;
}
*current_tool_name_sent = false;
if tool_id < streamed_args_for_tool.len() {
streamed_args_for_tool[tool_id].clear();
}
*current_tool_id += 1;
}
}
// Update prev_tool_call_arr with current state
if *current_tool_id >= 0 {
ensure_capacity(*current_tool_id, prev_tool_call_arr, streamed_args_for_tool);
let tool_id = *current_tool_id as usize;
if tool_id < prev_tool_call_arr.len() {
prev_tool_call_arr[tool_id] = current_tool_call;
}
}
Ok(result)
}
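// Illustrative trace (editor's sketch, not part of the original change):
// streaming `{"name": "search", "arguments": {"query": "rust"}}` in two chunks
// through this helper first yields
//   ToolCallItem { tool_index: 0, name: Some("search"), parameters: "" }
// once the name is visible and validated, then argument diffs such as
// `{"query":"rust"}` as the partial JSON grows; when the object completes, the
// buffer is advanced past it and current_tool_id moves to the next tool.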
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_ends_with_partial_token() {
assert!(ends_with_partial_token("hello <|py", "<|python_tag|>").is_some());
assert!(ends_with_partial_token("hello <|python_tag", "<|python_tag|>").is_some());
assert!(ends_with_partial_token("hello <|python_tag|>", "<|python_tag|>").is_none());
assert!(ends_with_partial_token("", "<|python_tag|>").is_none());
assert!(ends_with_partial_token("hello world", "<|python_tag|>").is_none());
}
#[test]
fn test_reset_current_tool_state() {
let mut buffer = String::from("partial json");
let mut current_tool_name_sent = true;
let mut streamed_args = vec!["tool0_args".to_string(), "tool1_partial".to_string()];
let prev_tools = vec![serde_json::json!({"name": "tool0"})];
reset_current_tool_state(
&mut buffer,
&mut current_tool_name_sent,
&mut streamed_args,
&prev_tools,
);
assert_eq!(buffer, "");
assert!(!current_tool_name_sent);
assert_eq!(streamed_args.len(), 1); // Popped the partial tool1 args
assert_eq!(streamed_args[0], "tool0_args");
}
#[test]
fn test_reset_current_tool_state_no_pop_when_synced() {
let mut buffer = String::from("partial json");
let mut current_tool_name_sent = true;
let mut streamed_args = vec!["tool0_args".to_string()];
let prev_tools = vec![serde_json::json!({"name": "tool0"})];
reset_current_tool_state(
&mut buffer,
&mut current_tool_name_sent,
&mut streamed_args,
&prev_tools,
);
assert_eq!(buffer, "");
assert!(!current_tool_name_sent);
assert_eq!(streamed_args.len(), 1); // No pop, lengths matched
}
#[test]
fn test_reset_parser_state() {
let mut buffer = String::from("some buffer");
let mut prev_tools = vec![serde_json::json!({"name": "tool0"})];
let mut current_tool_id = 5;
let mut current_tool_name_sent = true;
let mut streamed_args = vec!["args".to_string()];
reset_parser_state(
&mut buffer,
&mut prev_tools,
&mut current_tool_id,
&mut current_tool_name_sent,
&mut streamed_args,
);
assert_eq!(buffer, "");
assert_eq!(prev_tools.len(), 0);
assert_eq!(current_tool_id, 0);
assert!(!current_tool_name_sent);
assert_eq!(streamed_args.len(), 0);
}
#[test]
fn test_ensure_capacity() {
let mut prev_tools = vec![];
let mut streamed_args = vec![];
ensure_capacity(2, &mut prev_tools, &mut streamed_args);
assert_eq!(prev_tools.len(), 3);
assert_eq!(streamed_args.len(), 3);
assert_eq!(prev_tools[0], Value::Null);
assert_eq!(streamed_args[0], "");
}
#[test]
fn test_ensure_capacity_negative_id() {
let mut prev_tools = vec![];
let mut streamed_args = vec![];
ensure_capacity(-1, &mut prev_tools, &mut streamed_args);
// Should not resize for negative ID
assert_eq!(prev_tools.len(), 0);
assert_eq!(streamed_args.len(), 0);
}
#[test]
fn test_is_complete_json() {
assert!(is_complete_json(r#"{"name": "test"}"#));
assert!(is_complete_json("[1, 2, 3]"));
assert!(is_complete_json("42"));
assert!(is_complete_json("true"));
assert!(!is_complete_json(r#"{"name": "#));
assert!(!is_complete_json("[1, 2,"));
}
#[test]
fn test_normalize_arguments_field() {
// Case 1: Has parameters, no arguments
let obj = serde_json::json!({
"name": "test",
"parameters": {"key": "value"}
});
let normalized = normalize_arguments_field(obj);
assert_eq!(
normalized.get("arguments").unwrap(),
&serde_json::json!({"key": "value"})
);
// Case 2: Already has arguments
let obj = serde_json::json!({
"name": "test",
"arguments": {"key": "value"}
});
let normalized = normalize_arguments_field(obj.clone());
assert_eq!(normalized, obj);
// Case 3: No parameters or arguments
let obj = serde_json::json!({"name": "test"});
let normalized = normalize_arguments_field(obj.clone());
assert_eq!(normalized, obj);
}
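// Sketch test (editor's addition, for illustration): exercises get_tool_indices,
// assuming the Tool/Function field layout from crate::protocols::spec used
// elsewhere in this change.
#[test]
fn test_get_tool_indices() {
use crate::protocols::spec::Function;
let tools = vec![Tool {
tool_type: "function".to_string(),
function: Function {
name: "search".to_string(),
description: Some("Search for information".to_string()),
parameters: serde_json::json!({"type": "object"}),
},
}];
let indices = get_tool_indices(&tools);
assert_eq!(indices.get("search"), Some(&0));
assert!(indices.get("unknown").is_none());
}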
}
use async_trait::async_trait;
use serde_json::Value;
use crate::protocols::spec::Tool;
use crate::tool_parser::{
errors::{ToolParserError, ToolParserResult},
parsers::helpers,
partial_json::PartialJson,
traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall},
};
/// JSON format parser for tool calls
...@@ -18,6 +20,24 @@ use crate::tool_parser::{
pub struct JsonParser {
/// Parser for handling incomplete JSON during streaming
partial_json: PartialJson,
/// Buffer for accumulating incomplete patterns across chunks
buffer: String,
/// Stores complete tool call info (name and arguments) for each tool being parsed
prev_tool_call_arr: Vec<Value>,
/// Index of currently streaming tool call (-1 means no active tool)
current_tool_id: i32,
/// Flag for whether current tool's name has been sent to client
current_tool_name_sent: bool,
/// Tracks raw JSON string content streamed to client for each tool's arguments
streamed_args_for_tool: Vec<String>,
/// Separator between multiple tool calls
tool_call_separator: &'static str,
}
impl JsonParser {
...@@ -25,6 +45,12 @@ impl JsonParser {
pub fn new() -> Self {
Self {
partial_json: PartialJson::default(),
buffer: String::new(),
prev_tool_call_arr: Vec::new(),
current_tool_id: -1,
current_tool_name_sent: false,
streamed_args_for_tool: Vec::new(),
tool_call_separator: ",",
}
}
...@@ -158,25 +184,9 @@ impl JsonParser {
Ok(tools)
}
/// Check if text contains tool calls
fn has_tool_call(&self, text: &str) -> bool {
text.contains('[') || text.contains('{')
}
}
...@@ -206,79 +216,62 @@ impl ToolParser for JsonParser {
}
}
async fn parse_incremental(
&mut self,
chunk: &str,
tools: &[Tool],
) -> ToolParserResult<StreamingParseResult> {
// Append new text to buffer
self.buffer.push_str(chunk);
let current_text = &self.buffer.clone();
// Check if current_text has tool_call
let has_tool_start = self.has_tool_call(current_text)
|| (self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator));
if !has_tool_start {
let normal_text = self.buffer.clone();
self.buffer.clear();
return Ok(StreamingParseResult {
normal_text,
calls: vec![],
});
}
// Build tool indices
let tool_indices = helpers::get_tool_indices(tools);
// Determine start index for JSON parsing
// JSON can start with [ (array) or { (single object)
let start_idx = if let Some(bracket_pos) = current_text.find('[') {
let brace_pos = current_text.find('{');
match brace_pos {
Some(bp) if bp < bracket_pos => bp,
_ => bracket_pos,
}
} else if let Some(brace_pos) = current_text.find('{') {
brace_pos
} else if self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator) {
self.tool_call_separator.len()
} else {
0
};
helpers::handle_json_tool_streaming(
current_text,
start_idx,
&mut self.partial_json,
&tool_indices,
&mut self.buffer,
&mut self.current_tool_id,
&mut self.current_tool_name_sent,
&mut self.streamed_args_for_tool,
&mut self.prev_tool_call_arr,
)
}
fn detect_format(&self, text: &str) -> bool {
let trimmed = text.trim();
(trimmed.starts_with('[') || trimmed.starts_with('{')) && trimmed.contains(r#""name""#)
}
}
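// Editor's usage sketch (commented; assumes an async context and tools built
// like the benchmark's create_test_tools()):
//
// let mut parser = JsonParser::new();
// for chunk in [r#"{"name": "sea"#, r#"rch", "arguments": {"query": "rust"}}"#] {
//     let result = parser.parse_incremental(chunk, &tools).await?;
//     if !result.normal_text.is_empty() { /* forward as plain text */ }
//     for call in result.calls { /* forward ToolCallItem deltas to the client */ }
// }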
use async_trait::async_trait;
use regex::Regex;
use serde_json::Value;
use crate::protocols::spec::Tool;
use crate::tool_parser::{
errors::ToolParserResult,
parsers::helpers,
traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
};
/// Kimi K2 format parser for tool calls
...@@ -19,12 +21,32 @@ use crate::tool_parser::{
/// - Function calls with explicit indexing
/// - JSON arguments
pub struct KimiK2Parser {
/// Regex for extracting complete tool calls
tool_call_extractor: Regex,
/// Regex for extracting partial tool calls (streaming)
stream_tool_call_extractor: Regex,
/// Regex pattern for removing completed tool calls from buffer
tool_call_end_pattern: Regex,
/// Robust parser for ids like "functions.search:0" or fallback "search:0"
tool_call_id_regex: Regex,
/// Buffer for accumulating incomplete patterns across chunks
buffer: String,
/// Stores complete tool call info (name and arguments) for each tool being parsed
prev_tool_call_arr: Vec<Value>,
/// Index of currently streaming tool call (-1 means no active tool)
current_tool_id: i32,
/// Flag for whether current tool's name has been sent to client
current_tool_name_sent: bool,
/// Tracks raw JSON string content streamed to client for each tool's arguments
streamed_args_for_tool: Vec<String>,
/// Tracks the last arguments sent for incremental diffing
last_arguments: String,
}
impl KimiK2Parser {
...@@ -38,10 +60,25 @@ impl KimiK2Parser {
let stream_pattern = r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>\{.*)";
let stream_tool_call_extractor = Regex::new(stream_pattern).expect("Valid regex pattern");
// Pattern for removing completed tool calls
let end_pattern = r"<\|tool_call_begin\|>.*?<\|tool_call_end\|>";
let tool_call_end_pattern = Regex::new(end_pattern).expect("Valid regex pattern");
// Robust parser for ids like "functions.search:0" or fallback "search:0"
let id_pattern = r"^(?:functions\.)?(?P<name>[\w\.]+):(?P<index>\d+)$";
let tool_call_id_regex = Regex::new(id_pattern).expect("Valid regex pattern");
Self {
tool_call_extractor,
stream_tool_call_extractor,
tool_call_end_pattern,
tool_call_id_regex,
buffer: String::new(),
prev_tool_call_arr: Vec::new(),
current_tool_id: -1,
current_tool_name_sent: false,
streamed_args_for_tool: Vec::new(),
last_arguments: String::new(),
}
}
...@@ -52,22 +89,13 @@ impl KimiK2Parser {
/// Parse function ID to extract name and index
fn parse_function_id(&self, id: &str) -> Option<(String, usize)> {
if let Some(captures) = self.tool_call_id_regex.captures(id) {
let name = captures.name("name")?.as_str().to_string();
let index = captures.name("index")?.as_str().parse::<usize>().ok()?;
Some((name, index))
} else {
None
}
}
}
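// Examples, per tool_call_id_regex above: "functions.search:0" -> ("search", 0),
// and the fallback form "search:3" -> ("search", 3).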
...@@ -140,107 +168,172 @@ impl ToolParser for KimiK2Parser {
}
async fn parse_incremental(
&mut self,
chunk: &str,
tools: &[Tool],
) -> ToolParserResult<StreamingParseResult> {
self.buffer.push_str(chunk);
let current_text = &self.buffer.clone();
// Check if we have a tool call (either the start token or individual tool call)
let has_tool_call =
self.has_tool_markers(current_text) || current_text.contains("<|tool_call_begin|>");
if !has_tool_call {
// No tool markers detected - return all buffered content as normal text
let mut normal_text = std::mem::take(&mut self.buffer);
// Remove end tokens if present
for e_token in ["<|tool_calls_section_end|>", "<|tool_call_end|>"] {
normal_text = normal_text.replace(e_token, "");
}
return Ok(StreamingParseResult {
normal_text,
calls: vec![],
});
}
// Build tool indices for validation
let tool_indices = helpers::get_tool_indices(tools);
let mut calls: Vec<ToolCallItem> = Vec::new();
// Try to match streaming pattern
if let Some(captures) = self.stream_tool_call_extractor.captures(current_text) {
if let (Some(id_match), Some(args_match)) = (
captures.name("tool_call_id"),
captures.name("function_arguments"),
) {
let function_id = id_match.as_str();
let function_args = args_match.as_str();
// Parse function ID
if let Some((func_name, _index)) = self.parse_function_id(function_id) {
// Validate tool name
if !tool_indices.contains_key(&func_name) {
// Invalid tool name - skip this tool, preserve indexing for next tool
tracing::warn!("Invalid tool name '{}' - skipping", func_name);
helpers::reset_current_tool_state(
&mut self.buffer,
&mut self.current_tool_name_sent,
&mut self.streamed_args_for_tool,
&self.prev_tool_call_arr,
);
return Ok(StreamingParseResult::default());
}
// Initialize state if this is the first tool call
if self.current_tool_id == -1 {
self.current_tool_id = 0;
self.prev_tool_call_arr = Vec::new();
self.streamed_args_for_tool = vec![String::new()];
}
// Ensure we have enough entries in our tracking arrays
helpers::ensure_capacity(
self.current_tool_id,
&mut self.prev_tool_call_arr,
&mut self.streamed_args_for_tool,
);
// Send tool name if not sent yet
if !self.current_tool_name_sent {
calls.push(ToolCallItem {
tool_index: self.current_tool_id as usize,
name: Some(func_name.clone()),
parameters: String::new(),
});
self.current_tool_name_sent = true;
// Store the tool call info for serving layer completions endpoint
let tool_id = self.current_tool_id as usize;
if self.prev_tool_call_arr.len() <= tool_id {
self.prev_tool_call_arr
.resize_with(tool_id + 1, || Value::Null);
}
self.prev_tool_call_arr[tool_id] = serde_json::json!({
"name": func_name,
"arguments": {},
});
} else {
// Compute incremental diff
let argument_diff = if function_args.starts_with(&self.last_arguments) {
&function_args[self.last_arguments.len()..]
} else {
function_args
};
// Split by end token before sending (like Python does)
let parsed_args_diff =
if let Some(pos) = argument_diff.find("<|tool_call_end|>") {
&argument_diff[..pos]
} else {
argument_diff
};
if !parsed_args_diff.is_empty() {
calls.push(ToolCallItem {
tool_index: self.current_tool_id as usize,
name: None,
parameters: parsed_args_diff.to_string(),
});
// Note: Python adds full diff to _last_arguments, not just parsed part
self.last_arguments.push_str(argument_diff);
let tool_id = self.current_tool_id as usize;
if tool_id < self.streamed_args_for_tool.len() {
self.streamed_args_for_tool[tool_id].push_str(parsed_args_diff);
}
}
// Check completeness - split by end token first
let parsed_args = if let Some(pos) = function_args.find("<|tool_call_end|>")
{
&function_args[..pos]
} else {
function_args
};
if helpers::is_complete_json(parsed_args) {
// Update the stored arguments
if let Ok(parsed_args_value) =
serde_json::from_str::<Value>(parsed_args)
{
let tool_id = self.current_tool_id as usize;
if tool_id < self.prev_tool_call_arr.len() {
if let Some(obj) =
self.prev_tool_call_arr[tool_id].as_object_mut()
{
obj.insert("arguments".to_string(), parsed_args_value);
}
}
}
// Find the end of the current tool call and remove only that part from buffer
if let Some(mat) = self.tool_call_end_pattern.find(current_text) {
// Remove the completed tool call from buffer, keep any remaining content
self.buffer = current_text[mat.end()..].to_string();
} else {
self.buffer.clear();
}
let result = StreamingParseResult {
normal_text: String::new(),
calls,
};
self.current_tool_id += 1;
self.last_arguments.clear();
self.current_tool_name_sent = false;
return Ok(result);
}
}
}
}
}
Ok(StreamingParseResult {
normal_text: String::new(),
calls,
})
}
fn detect_format(&self, text: &str) -> bool {
...@@ -2,23 +2,44 @@ use async_trait::async_trait;
use serde_json::Value;
use uuid;
use crate::protocols::spec::Tool;
use crate::tool_parser::{
errors::{ToolParserError, ToolParserResult},
parsers::helpers,
partial_json::PartialJson,
traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall},
};
/// Llama 3.2 format parser for tool calls
///
/// Handles the Llama 3.2 specific format:
/// `<|python_tag|>{"name": "func", "parameters": {...}}`
///
/// Also supports plain JSON without the python_tag prefix
pub struct LlamaParser {
/// Parser for handling incomplete JSON during streaming
partial_json: PartialJson,
/// Buffer for accumulating incomplete patterns across chunks
buffer: String,
/// Stores complete tool call info (name and arguments) for each tool being parsed
prev_tool_call_arr: Vec<Value>,
/// Index of currently streaming tool call (-1 means no active tool)
current_tool_id: i32,
/// Flag for whether current tool's name has been sent to client
current_tool_name_sent: bool,
/// Tracks raw JSON string content streamed to client for each tool's arguments
streamed_args_for_tool: Vec<String>,
/// Token configuration
bot_token: &'static str,
tool_call_separator: &'static str,
}
impl LlamaParser {
...@@ -26,6 +47,13 @@ impl LlamaParser {
pub fn new() -> Self {
Self {
partial_json: PartialJson::default(),
buffer: String::new(),
prev_tool_call_arr: Vec::new(),
current_tool_id: -1,
current_tool_name_sent: false,
streamed_args_for_tool: Vec::new(),
bot_token: "<|python_tag|>",
tool_call_separator: ";",
}
}
...@@ -76,39 +104,6 @@ impl LlamaParser {
}
}
/// Parse semicolon-separated JSON objects
fn parse_semicolon_separated(&self, content: &str) -> ToolParserResult<Vec<ToolCall>> {
let mut all_tools = Vec::new();
...@@ -136,6 +131,11 @@ impl LlamaParser {
Ok(all_tools)
}
/// Check if text has tool call
fn has_tool_call(&self, text: &str) -> bool {
text.contains("<|python_tag|>") || text.contains('{')
}
}
impl Default for LlamaParser {
...@@ -185,137 +185,57 @@ impl ToolParser for LlamaParser {
}
}
async fn parse_incremental(
&mut self,
chunk: &str,
tools: &[Tool],
) -> ToolParserResult<StreamingParseResult> {
// Append new text to buffer
self.buffer.push_str(chunk);
let current_text = &self.buffer.clone();
// Check if current_text has tool_call
let has_tool_start = self.has_tool_call(current_text)
|| (self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator));
if !has_tool_start {
// Only clear buffer if we're sure no tool call is starting
if helpers::ends_with_partial_token(&self.buffer, self.bot_token).is_none() {
let normal_text = self.buffer.clone();
self.buffer.clear();
return Ok(StreamingParseResult {
normal_text,
calls: vec![],
});
} else {
// Might be partial bot_token, keep buffering
return Ok(StreamingParseResult::default());
}
}
// Build tool indices
let tool_indices = helpers::get_tool_indices(tools);
// Determine start index for JSON parsing
let start_idx = if let Some(pos) = current_text.find(self.bot_token) {
pos + self.bot_token.len()
} else if self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator) {
self.tool_call_separator.len()
} else {
0
};
helpers::handle_json_tool_streaming(
current_text,
start_idx,
&mut self.partial_json,
&tool_indices,
&mut self.buffer,
&mut self.current_tool_id,
&mut self.current_tool_name_sent,
&mut self.streamed_args_for_tool,
&mut self.prev_tool_call_arr,
)
}
fn detect_format(&self, text: &str) -> bool {
use async_trait::async_trait;
use serde_json::Value;
use crate::protocols::spec::Tool;
use crate::tool_parser::{
errors::{ToolParserError, ToolParserResult},
parsers::helpers,
partial_json::PartialJson,
traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall},
};
/// Mistral format parser for tool calls
...@@ -21,6 +23,25 @@ use crate::tool_parser::{
pub struct MistralParser {
/// Parser for handling incomplete JSON during streaming
partial_json: PartialJson,
/// Buffer for accumulating incomplete patterns across chunks
buffer: String,
/// Stores complete tool call info (name and arguments) for each tool being parsed
prev_tool_call_arr: Vec<Value>,
/// Index of currently streaming tool call (-1 means no active tool)
current_tool_id: i32,
/// Flag for whether current tool's name has been sent to client
current_tool_name_sent: bool,
/// Tracks raw JSON string content streamed to client for each tool's arguments
streamed_args_for_tool: Vec<String>,
/// Token configuration
bot_token: &'static str,
tool_call_separator: &'static str,
}
impl MistralParser {
...@@ -28,19 +49,16 @@ impl MistralParser {
pub fn new() -> Self {
Self {
partial_json: PartialJson::default(),
buffer: String::new(),
prev_tool_call_arr: Vec::new(),
current_tool_id: -1,
current_tool_name_sent: false,
streamed_args_for_tool: Vec::new(),
bot_token: "[TOOL_CALLS] [",
tool_call_separator: ", ",
}
}
fn extract_json_array_with_pos<'a>(&self, text: &'a str) -> Option<(usize, &'a str)> {
const BOT_TOKEN: &str = "[TOOL_CALLS] [";
...@@ -100,14 +118,14 @@ impl MistralParser {
let mut tools = Vec::new();
if let Value::Array(arr) = value {
for item in arr.iter() {
if let Some(tool) = self.parse_single_object(item)? {
tools.push(tool);
}
}
} else {
// Single object case (shouldn't happen with Mistral format, but handle it)
if let Some(tool) = self.parse_single_object(&value)? {
tools.push(tool);
}
}
...@@ -116,7 +134,7 @@ impl MistralParser {
}
/// Parse a single JSON object into a ToolCall
fn parse_single_object(&self, obj: &Value) -> ToolParserResult<Option<ToolCall>> {
let name = obj.get("name").and_then(|v| v.as_str());
if let Some(name) = name {
...@@ -128,8 +146,12 @@ impl MistralParser {
let arguments = serde_json::to_string(args)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
// Generate unique ID
let id = obj
.get("id")
.and_then(|v| v.as_str())
.map(String::from)
.unwrap_or_else(|| format!("mistral_call_{}", uuid::Uuid::new_v4()));
Ok(Some(ToolCall {
id,
...@@ -188,95 +210,57 @@ impl ToolParser for MistralParser {
}
async fn parse_incremental(
&mut self,
chunk: &str,
tools: &[Tool],
) -> ToolParserResult<StreamingParseResult> {
// Append new text to buffer
self.buffer.push_str(chunk);
let current_text = &self.buffer.clone();
// Check if current_text has tool_call
let has_tool_start = self.has_tool_markers(current_text)
|| (self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator));
if !has_tool_start {
// Only clear buffer if we're sure no tool call is starting
if helpers::ends_with_partial_token(&self.buffer, self.bot_token).is_none() {
let normal_text = self.buffer.clone();
self.buffer.clear();
return Ok(StreamingParseResult {
normal_text,
calls: vec![],
});
} else {
// Might be partial bot_token, keep buffering
return Ok(StreamingParseResult::default());
}
}
// Build tool indices
let tool_indices = helpers::get_tool_indices(tools);
// Determine start index for JSON parsing
let start_idx = if let Some(pos) = current_text.find(self.bot_token) {
pos + self.bot_token.len()
} else if self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator) {
self.tool_call_separator.len()
} else {
0
};
helpers::handle_json_tool_streaming(
current_text,
start_idx,
&mut self.partial_json,
&tool_indices,
&mut self.buffer,
&mut self.current_tool_id,
&mut self.current_tool_name_sent,
&mut self.streamed_args_for_tool,
&mut self.prev_tool_call_arr,
)
}
fn detect_format(&self, text: &str) -> bool {
...@@ -15,6 +15,9 @@ pub mod pythonic_parser;
pub mod qwen_parser;
pub mod step3_parser;
// Shared helpers and utilities
pub mod helpers;
// Re-export parser types for convenience
pub use deepseek_parser::DeepSeekParser;
pub use glm4_moe_parser::Glm4MoeParser;
...@@ -15,11 +15,13 @@ use rustpython_parser::{parse, Mode};
use serde_json::{Map, Number, Value};
use std::sync::OnceLock;
use crate::protocols::spec::Tool;
use crate::tool_parser::{
errors::{ToolParserError, ToolParserResult},
parsers::helpers,
traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
};
static PYTHONIC_BLOCK_REGEX: OnceLock<Regex> = OnceLock::new();
...@@ -37,13 +39,23 @@ fn pythonic_block_regex() -> &'static Regex {
}
/// Parser for Pythonic tool call format
pub struct PythonicParser {
/// Buffer for accumulating chunks
buffer: String,
}
impl Default for PythonicParser {
fn default() -> Self {
Self::new()
}
}
impl PythonicParser {
/// Create a new Pythonic parser
pub fn new() -> Self {
Self {
buffer: String::new(),
}
}
/// Extract the first pythonic tool call block and return it along with the
...@@ -105,23 +117,90 @@ impl ToolParser for PythonicParser {
}
async fn parse_incremental(
&mut self,
chunk: &str,
tools: &[Tool],
) -> ToolParserResult<StreamingParseResult> {
self.buffer.push_str(chunk);
let cleaned = Self::strip_special_tokens(&self.buffer);
// Look for opening bracket
if let Some(start) = cleaned.find('[') {
let normal_text = if start > 0 {
cleaned[..start].to_string()
} else {
String::new()
};
// Look for matching closing bracket
if let Some(end) = find_matching_bracket(&cleaned, start) {
// Found complete tool call - extract it and parse using parse_complete
let call_text = &cleaned[start..=end];
match self.parse_complete(call_text).await {
Ok((_, calls)) => {
// Update buffer with remaining text after tool call
let remaining_text = &cleaned[end + 1..];
self.buffer = remaining_text.to_string();
// Validate tool names and convert ToolCall to ToolCallItem
let tool_indices = helpers::get_tool_indices(tools);
let items: Vec<ToolCallItem> = calls
.into_iter()
.enumerate()
.filter_map(|(idx, tool)| {
if !tool_indices.contains_key(&tool.function.name) {
tracing::warn!(
"Invalid tool name '{}' - skipping",
tool.function.name
);
return None;
}
Some(ToolCallItem {
tool_index: idx,
name: Some(tool.function.name),
parameters: tool.function.arguments,
})
})
.collect();
return Ok(StreamingParseResult {
normal_text,
calls: items,
});
}
Err(e) => {
tracing::warn!("Failed to parse pythonic tool call: {}", e);
// Clear buffer on error
self.buffer.clear();
return Ok(StreamingParseResult::default());
}
} }
} else {
// We have an opening bracket but no closing bracket yet
// Put back everything from the bracket onwards
self.buffer = cleaned[start..].to_string();
if !normal_text.is_empty() {
return Ok(StreamingParseResult {
normal_text,
calls: vec![],
});
}
// Still accumulating a potential tool call
return Ok(StreamingParseResult::default());
} }
} }
// No tool call bracket found
self.buffer.clear();
Ok(StreamingParseResult {
normal_text: cleaned,
calls: vec![],
})
} }
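// Editor's sketch of the chunked behavior (not part of the original change):
// a chunk like `[get_weather(city=` has no matching `]`, so it stays buffered
// and nothing is emitted; once a later chunk supplies `"Paris")]`, the complete
// block is parsed, names are validated against `tools`, and ToolCallItems are
// returned together with any text that preceded the `[`.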
fn detect_format(&self, text: &str) -> bool {
...@@ -134,6 +213,25 @@ impl ToolParser for PythonicParser {
}
}
/// Find the matching closing bracket for the opening bracket at start position.
/// Properly handles nested brackets.
fn find_matching_bracket(buffer: &str, start: usize) -> Option<usize> {
let mut bracket_count = 0;
// Iterate with byte offsets so `start` (from str::find) and the returned index
// can both be used to slice the buffer safely, even around multi-byte chars.
for (i, ch) in buffer.char_indices() {
if i < start {
continue;
}
if ch == '[' {
bracket_count += 1;
} else if ch == ']' {
bracket_count -= 1;
if bracket_count == 0 {
return Some(i);
}
}
}
None // No matching bracket found
}
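// Sketch tests (editor's addition) for the bracket matcher above.
#[cfg(test)]
mod find_matching_bracket_tests {
use super::find_matching_bracket;
#[test]
fn handles_nesting_and_incomplete_input() {
let text = "[f([1, [2]]), g()] tail";
let end = find_matching_bracket(text, 0).expect("balanced brackets");
assert_eq!(&text[..=end], "[f([1, [2]]), g()]");
// No closing bracket yet: the caller keeps buffering
assert_eq!(find_matching_bracket("[incomplete", 0), None);
}
}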
fn parse_python_expression(source: &str) -> ToolParserResult<Expr> {
let module = parse(source, Mode::Expression, "<pythonic_tool_call>")
.map_err(|err| ToolParserError::ParsingFailed(err.to_string()))?;
...@@ -2,12 +2,14 @@ use async_trait::async_trait;
use regex::Regex;
use serde_json::Value;
use crate::protocols::spec::Tool;
use crate::tool_parser::{
errors::{ToolParserError, ToolParserResult},
parsers::helpers,
partial_json::PartialJson,
traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall},
};
/// Qwen format parser for tool calls
...@@ -19,11 +21,36 @@ use crate::tool_parser::{
/// - XML-style tags with JSON content
/// - Support for multiple sequential tool calls
/// - Newline-aware parsing
/// - Buffering for partial end tokens
pub struct QwenParser {
/// Parser for handling incomplete JSON during streaming
partial_json: PartialJson,
/// Regex for extracting tool calls in parse_complete
extractor: Regex,
/// Buffer for accumulating incomplete patterns across chunks
buffer: String,
/// Stores complete tool call info (name and arguments) for each tool being parsed
prev_tool_call_arr: Vec<Value>,
/// Index of currently streaming tool call (-1 means no active tool)
current_tool_id: i32,
/// Flag for whether current tool's name has been sent to client
current_tool_name_sent: bool,
/// Tracks raw JSON string content streamed to client for each tool's arguments
streamed_args_for_tool: Vec<String>,
/// Buffer for normal text that might precede partial end tokens
normal_text_buffer: String,
/// Token configuration
bot_token: &'static str,
eot_token: &'static str,
tool_call_separator: &'static str,
}
impl QwenParser {
...@@ -36,11 +63,20 @@ impl QwenParser {
Self {
partial_json: PartialJson::default(),
extractor,
buffer: String::new(),
prev_tool_call_arr: Vec::new(),
current_tool_id: -1,
current_tool_name_sent: false,
streamed_args_for_tool: Vec::new(),
normal_text_buffer: String::new(),
bot_token: "<tool_call>\n",
eot_token: "\n</tool_call>",
tool_call_separator: "\n",
} }
} }
/// Parse a single JSON object into a ToolCall /// Parse a single JSON object into a ToolCall
fn parse_single_object(&self, obj: &Value, index: usize) -> ToolParserResult<Option<ToolCall>> { fn parse_single_object(&self, obj: &Value) -> ToolParserResult<Option<ToolCall>> {
let name = obj.get("name").and_then(|v| v.as_str()); let name = obj.get("name").and_then(|v| v.as_str());
if let Some(name) = name { if let Some(name) = name {
@@ -52,8 +88,12 @@ impl QwenParser {
             let arguments = serde_json::to_string(args)
                 .map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;

-            // Generate ID with index for multiple tools
-            let id = format!("qwen_call_{}", index);
+            // Generate unique ID
+            let id = obj
+                .get("id")
+                .and_then(|v| v.as_str())
+                .map(String::from)
+                .unwrap_or_else(|| format!("qwen_call_{}", uuid::Uuid::new_v4()));

             Ok(Some(ToolCall {
                 id,
@@ -73,42 +113,9 @@
         text.contains("<tool_call>")
     }

-    /// Find the start position of a tool call
-    fn find_tool_start(&self, text: &str) -> Option<usize> {
-        text.find("<tool_call>\n")
-    }
-
-    /// Find the end position of a tool call
-    fn find_tool_end(&self, text: &str, start_pos: usize) -> Option<usize> {
-        let search_from = start_pos + "<tool_call>\n".len();
-        text[search_from..]
-            .find("\n</tool_call>")
-            .map(|pos| search_from + pos + "\n</tool_call>".len())
-    }
-
-    /// Check if buffer ends with a partial token
-    fn ends_with_partial_token(&self, buffer: &str) -> Option<usize> {
-        // Check for partial start token
-        let start_token = "<tool_call>\n";
-        // Use inclusive range to check if entire buffer could be a prefix
-        for i in 1..=start_token.len().min(buffer.len()) {
-            if start_token.starts_with(&buffer[buffer.len() - i..]) {
-                return Some(i);
-            }
-        }
-
-        // Check for partial end token
-        let end_token = "\n</tool_call>";
-        // Only check if buffer ends with a partial match (not the complete token without newline)
-        // If buffer ends with "</tool_call>", that's not a partial token - it's missing the newline
-        if buffer.ends_with("</tool_call>") {
-            // This is a complete end tag, just missing the leading newline
-            // Not a partial token situation
-            return None;
-        }
-
-        // Use inclusive range to check if entire buffer could be a prefix
-        (1..=end_token.len().min(buffer.len()))
-            .find(|&i| end_token.starts_with(&buffer[buffer.len() - i..]))
-    }
+    /// Check if text has tool call
+    fn has_tool_call(&self, text: &str) -> bool {
+        text.contains("<tool_call>")
+    }
 }
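The removed `ends_with_partial_token` now lives in `parsers::helpers` as a two-argument version (see its call sites below). The core idea is a suffix/prefix overlap test; a standalone sketch with an illustrative function name:

```rust
// How many trailing chars of `buffer` could be the start of `token`?
// Some(n) means the last n chars must be held back from the client.
fn partial_overlap(buffer: &str, token: &str) -> Option<usize> {
    (1..=token.len().min(buffer.len()))
        .find(|&i| token.starts_with(&buffer[buffer.len() - i..]))
}

// "<tool_c" may still grow into "<tool_call>\n", so hold back 7 chars.
assert_eq!(partial_overlap("Sure, checking. <tool_c", "<tool_call>\n"), Some(7));
// No overlap at all: the buffer is safe to flush as normal text.
assert_eq!(partial_overlap("plain text", "<tool_call>\n"), None);
```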
@@ -132,17 +139,17 @@ impl ToolParser for QwenParser {
         // Extract tool calls
         let mut tools = Vec::new();
-        for (index, captures) in self.extractor.captures_iter(text).enumerate() {
+        for captures in self.extractor.captures_iter(text) {
             if let Some(json_str) = captures.get(1) {
                 let parsed = serde_json::from_str::<Value>(json_str.as_str().trim())
                     .map_err(|e| ToolParserError::ParsingFailed(e.to_string()))
-                    .and_then(|v| self.parse_single_object(&v, index));
+                    .and_then(|v| self.parse_single_object(&v));

                 match parsed {
                     Ok(Some(tool)) => tools.push(tool),
                     Ok(None) => continue,
                     Err(e) => {
-                        tracing::warn!("Failed to parse tool call {}: {:?}", index, e);
+                        tracing::warn!("Failed to parse tool call: {:?}", e);
                         continue;
                     }
                 }
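Net effect on the non-streaming path: IDs are no longer positional, so repeated parses of the same text yield distinct IDs. A hedged sketch against the illustrative payload above, assuming the trait's usual `(normal_text, tools)` return shape:

```rust
// Sketch only: `parse_complete` is async in this trait.
let parser = QwenParser::new();
let (normal_text, tools) = parser.parse_complete(QWEN_EXAMPLE).await?;
assert!(normal_text.is_empty());
assert_eq!(tools.len(), 2);
// IDs are now "qwen_call_<uuid>" (or the model-supplied "id"), not indexed.
assert!(tools[0].id.starts_with("qwen_call_"));
```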
@@ -158,103 +165,91 @@ impl ToolParser for QwenParser {
     }

     async fn parse_incremental(
-        &self,
+        &mut self,
         chunk: &str,
-        state: &mut ParseState,
-    ) -> ToolParserResult<StreamResult> {
-        state.buffer.push_str(chunk);
+        tools: &[Tool],
+    ) -> ToolParserResult<StreamingParseResult> {
+        // Append new text to buffer
+        self.buffer.push_str(chunk);
+        let current_text = &self.buffer.clone();

-        // Check for partial token at end of buffer
-        if let Some(_partial_len) = self.ends_with_partial_token(&state.buffer) {
-            // Hold back the partial token
-            return Ok(StreamResult::Incomplete);
-        }
-
-        // Check if we have the start marker
-        if !self.has_tool_markers(&state.buffer) {
-            // No tool markers detected - return all buffered content as normal text
-            let normal_text = std::mem::take(&mut state.buffer);
-            return Ok(StreamResult::NormalText(normal_text));
-        }
-
-        // Check for text before tool markers and extract it as normal text
-        if let Some(marker_pos) = state.buffer.find("<tool_call>") {
-            if marker_pos > 0 {
-                // We have text before the tool marker - extract it as normal text
-                let normal_text: String = state.buffer.drain(..marker_pos).collect();
-                return Ok(StreamResult::NormalText(normal_text));
-            }
-        }
-
-        // Find start and end positions
-        if let Some(start_pos) = self.find_tool_start(&state.buffer) {
-            // Check if we have the complete tool call
-            if let Some(end_pos) = self.find_tool_end(&state.buffer, start_pos) {
-                // Extract the JSON content
-                let json_start = start_pos + "<tool_call>\n".len();
-                let json_end = end_pos - "\n</tool_call>".len();
-                let json_str = &state.buffer[json_start..json_end];
-
-                // Parse the complete JSON
-                match serde_json::from_str::<Value>(json_str.trim()) {
-                    Ok(value) => {
-                        if let Some(tool) = self.parse_single_object(&value, 0)? {
-                            // Clear the consumed part from buffer using drain for efficiency
-                            state.buffer.drain(..end_pos);
-                            return Ok(StreamResult::ToolComplete(tool));
-                        }
-                    }
-                    Err(_) => {
-                        // JSON parsing failed, might be incomplete or malformed
-                        // If we have what looks like a complete tool call block, treat as normal text
-                        if state.buffer[start_pos..end_pos].contains("\n</tool_call>") {
-                            let malformed_text: String = state.buffer.drain(..end_pos).collect();
-                            return Ok(StreamResult::NormalText(malformed_text));
-                        }
-                    }
-                }
-            } else {
-                // We have start but no end yet - try partial parsing
-                let json_start = start_pos + "<tool_call>\n".len();
-                let partial_json = &state.buffer[json_start..];
-
-                // Remove trailing newline if present (might be start of end token)
-                let partial_json = partial_json.trim_end();
-
-                // Try to parse with partial JSON parser
-                match self.partial_json.parse_value(partial_json) {
-                    Ok((value, _consumed)) => {
-                        // Extract tool name if available
-                        if let Some(name) = value.get("name").and_then(|v| v.as_str()) {
-                            // Check if we've already sent the name
-                            if !state.in_string {
-                                state.in_string = true; // Use as flag for "name sent"
-                                return Ok(StreamResult::ToolName {
-                                    index: 0,
-                                    name: name.to_string(),
-                                });
-                            }
-                            // Check for arguments
-                            if let Some(args) = value.get("arguments") {
-                                if let Ok(args_str) = serde_json::to_string(args) {
-                                    return Ok(StreamResult::ToolArguments {
-                                        index: 0,
-                                        arguments: args_str,
-                                    });
-                                }
-                            }
-                        }
-                    }
-                    Err(_) => {
-                        // Failed to parse even as partial JSON
-                        // Keep buffering
-                    }
-                }
-            }
-        }
-
-        Ok(StreamResult::Incomplete)
+        // Check if current_text has tool_call
+        let has_tool_start = self.has_tool_call(current_text)
+            || (self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator));
+
+        if !has_tool_start {
+            // Only clear buffer if we're sure no tool call is starting
+            if helpers::ends_with_partial_token(&self.buffer, self.bot_token).is_none() {
+                let normal_text = self.buffer.clone();
+                self.buffer.clear();
+                return Ok(StreamingParseResult {
+                    normal_text,
+                    calls: vec![],
+                });
+            } else {
+                // Might be partial bot_token, keep buffering
+                return Ok(StreamingParseResult::default());
+            }
+        }
+
+        // Build tool indices
+        let tool_indices = helpers::get_tool_indices(tools);
+
+        // Determine start index for JSON parsing
+        let start_idx = if let Some(pos) = current_text.find(self.bot_token) {
+            pos + self.bot_token.len()
+        } else if self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator) {
+            self.tool_call_separator.len()
+        } else {
+            0
+        };
+
+        let mut result = helpers::handle_json_tool_streaming(
+            current_text,
+            start_idx,
+            &mut self.partial_json,
+            &tool_indices,
+            &mut self.buffer,
+            &mut self.current_tool_id,
+            &mut self.current_tool_name_sent,
+            &mut self.streamed_args_for_tool,
+            &mut self.prev_tool_call_arr,
+        )?;
+
+        // Qwen-specific: Handle partial end tokens in normal text
+        // After tool calls complete, normal text might contain partial "</tool_call>" tags
+        if !result.normal_text.is_empty() {
+            self.normal_text_buffer.push_str(&result.normal_text);
+
+            // Check if buffer contains complete end token (without leading newline)
+            let end_token_without_newline = &self.eot_token[1..]; // "</tool_call>"
+            if self.normal_text_buffer.contains(end_token_without_newline) {
+                // Complete end token found - clean it and return
+                let cleaned_text = self
+                    .normal_text_buffer
+                    .replace(end_token_without_newline, "");
+                self.normal_text_buffer.clear();
+                result.normal_text = cleaned_text;
+            } else {
+                // Check if buffer might contain partial end token at the end
+                if let Some(partial_match_len) = helpers::ends_with_partial_token(
+                    &self.normal_text_buffer,
+                    end_token_without_newline,
+                ) {
+                    // Keep potential partial match in buffer, return the rest
+                    let split_point = self.normal_text_buffer.len() - partial_match_len;
+                    result.normal_text = self.normal_text_buffer[..split_point].to_string();
+                    self.normal_text_buffer = self.normal_text_buffer[split_point..].to_string();
+                } else {
+                    // No partial match, return all buffered text
+                    result.normal_text = self.normal_text_buffer.clone();
+                    self.normal_text_buffer.clear();
+                }
+            }
+        }
+
+        Ok(result)
     }
     fn detect_format(&self, text: &str) -> bool {
...