Unverified Commit b658be6f authored by Chang Su's avatar Chang Su Committed by GitHub
Browse files

[router][grpc] Support tool call parser in streaming (#11160)

parent 5e786cca
......@@ -8,9 +8,9 @@
//! - Different model formats (JSON, Mistral, Qwen, Pythonic, etc.)
use criterion::{black_box, criterion_group, BenchmarkId, Criterion, Throughput};
use sglang_router_rs::tool_parser::{
registry::ParserRegistry, state::ParseState, types::StreamResult,
};
use serde_json::json;
use sglang_router_rs::protocols::spec::{Function, Tool};
use sglang_router_rs::tool_parser::{JsonParser, ToolParser, ToolParserFactory};
use std::collections::BTreeMap;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
......@@ -108,6 +108,40 @@ const STEP3_FORMAT: &str = r#"<step.tML version="0.1">
const GPT_OSS_FORMAT: &str = r#"<Channel.vector_search>{"collection": "technical_documentation", "query_embedding": [0.0234, -0.1456, 0.0891, 0.2341, -0.0567, 0.1234, 0.0456, -0.0789, 0.1567, 0.0234, -0.1123, 0.0678, 0.2345, -0.0456, 0.0891, 0.1234, -0.0567, 0.0789, 0.1456, -0.0234, 0.0891, 0.1567, -0.0678, 0.0345, 0.1234, -0.0456, 0.0789, 0.1891, -0.0234, 0.0567, 0.1345, -0.0891], "top_k": 10, "similarity_metric": "cosine", "filters": {"language": "en", "last_updated": {"$gte": "2023-01-01"}, "categories": {"$in": ["api", "sdk", "integration"]}}, "include_metadata": true, "rerank_with_cross_encoder": true}</Channel.vector_search>"#;
// Create test tools for parsers that need them
fn create_test_tools() -> Vec<Tool> {
vec![
Tool {
tool_type: "function".to_string(),
function: Function {
name: "search".to_string(),
description: Some("Search for information".to_string()),
parameters: json!({
"type": "object",
"properties": {
"query": {"type": "string"},
"limit": {"type": "number"}
}
}),
},
},
Tool {
tool_type: "function".to_string(),
function: Function {
name: "code_interpreter".to_string(),
description: Some("Execute code".to_string()),
parameters: json!({
"type": "object",
"properties": {
"language": {"type": "string"},
"code": {"type": "string"}
}
}),
},
},
]
}
// Large test data for stress testing
fn generate_large_json(num_tools: usize) -> String {
let mut tools = Vec::new();
......@@ -141,7 +175,7 @@ fn bench_registry_creation(c: &mut Criterion) {
b.iter_custom(|iters| {
let start = Instant::now();
for _ in 0..iters {
let registry = black_box(ParserRegistry::new());
let registry = black_box(ToolParserFactory::new());
// Force evaluation to prevent optimization
black_box(registry.list_parsers());
}
......@@ -168,7 +202,7 @@ fn bench_registry_creation(c: &mut Criterion) {
}
fn bench_parser_lookup(c: &mut Criterion) {
let registry = Arc::new(ParserRegistry::new());
let registry = Arc::new(ToolParserFactory::new());
let models = vec![
"gpt-4",
"mistral-large",
......@@ -227,7 +261,7 @@ fn bench_parser_lookup(c: &mut Criterion) {
fn bench_complete_parsing(c: &mut Criterion) {
let rt = Runtime::new().unwrap();
let registry = Arc::new(ParserRegistry::new());
let registry = Arc::new(ToolParserFactory::new());
let test_cases = vec![
("json_simple", "json", JSON_SIMPLE),
......@@ -295,7 +329,6 @@ fn bench_complete_parsing(c: &mut Criterion) {
fn bench_streaming_parsing(c: &mut Criterion) {
let rt = Runtime::new().unwrap();
let registry = Arc::new(ParserRegistry::new());
// Streaming test with chunked input
let chunks = vec![
......@@ -315,24 +348,21 @@ fn bench_streaming_parsing(c: &mut Criterion) {
let printed = Arc::new(AtomicBool::new(false));
group.bench_function("json_streaming", |b| {
let printed_clone = printed.clone();
let registry = registry.clone();
let rt = rt.handle().clone();
b.iter_custom(|iters| {
let parser = registry.get_parser("json").expect("Parser not found");
let tools = create_test_tools();
let start = Instant::now();
for _ in 0..iters {
let parser = parser.clone();
let mut state = ParseState::new();
let mut parser = JsonParser::new();
let mut complete_tools = Vec::new();
rt.block_on(async {
for chunk in &chunks {
if let StreamResult::ToolComplete(tool) =
parser.parse_incremental(chunk, &mut state).await.unwrap()
{
complete_tools.push(tool);
let result = parser.parse_incremental(chunk, &tools).await.unwrap();
if !result.calls.is_empty() {
complete_tools.extend(result.calls);
}
}
});
......@@ -368,7 +398,7 @@ fn bench_streaming_parsing(c: &mut Criterion) {
fn bench_concurrent_parsing(c: &mut Criterion) {
let rt = Runtime::new().unwrap();
let registry = Arc::new(ParserRegistry::new());
let registry = Arc::new(ToolParserFactory::new());
let parser = registry.get_parser("json").expect("Parser not found");
let thread_counts = vec![1, 2, 4, 8, 16, 32];
......@@ -456,7 +486,7 @@ fn bench_concurrent_parsing(c: &mut Criterion) {
fn bench_large_payloads(c: &mut Criterion) {
let rt = Runtime::new().unwrap();
let registry = Arc::new(ParserRegistry::new());
let registry = Arc::new(ToolParserFactory::new());
let parser = registry.get_parser("json").expect("Parser not found");
let sizes = vec![1, 10, 50, 100, 500];
......@@ -526,7 +556,7 @@ fn bench_parser_reuse(c: &mut Criterion) {
b.iter_custom(|iters| {
let start = Instant::now();
for _ in 0..iters {
let registry = ParserRegistry::new();
let registry = ToolParserFactory::new();
let parser = registry.get_parser("json").unwrap();
let result = rt.block_on(async { parser.parse_complete(JSON_SIMPLE).await });
black_box(result.unwrap());
......@@ -552,7 +582,7 @@ fn bench_parser_reuse(c: &mut Criterion) {
// Benchmark reusing registry
let printed_reuse = Arc::new(AtomicBool::new(false));
let shared_registry = Arc::new(ParserRegistry::new());
let shared_registry = Arc::new(ToolParserFactory::new());
group.bench_function("reuse_registry", |b| {
let printed_clone = printed_reuse.clone();
......@@ -627,7 +657,7 @@ fn bench_parser_reuse(c: &mut Criterion) {
fn bench_latency_distribution(c: &mut Criterion) {
let rt = Runtime::new().unwrap();
let registry = Arc::new(ParserRegistry::new());
let registry = Arc::new(ToolParserFactory::new());
let test_cases = vec![
("json", JSON_SIMPLE),
......
......@@ -7,7 +7,7 @@ use crate::policies::PolicyRegistry;
use crate::reasoning_parser::ParserFactory;
use crate::routers::RouterTrait;
use crate::tokenizer::traits::Tokenizer;
use crate::tool_parser::ParserRegistry;
use crate::tool_parser::ToolParserFactory;
use async_trait::async_trait;
use axum::{
body::Body,
......@@ -25,7 +25,7 @@ pub struct GrpcPDRouter {
policy_registry: Arc<PolicyRegistry>,
tokenizer: Arc<dyn Tokenizer>,
reasoning_parser_factory: ParserFactory,
tool_parser_registry: &'static ParserRegistry,
tool_parser_factory: ToolParserFactory,
dp_aware: bool,
api_key: Option<String>,
......@@ -50,9 +50,11 @@ impl GrpcPDRouter {
.as_ref()
.ok_or_else(|| "gRPC PD router requires reasoning parser factory".to_string())?
.clone();
let tool_parser_registry = ctx
.tool_parser_registry
.ok_or_else(|| "gRPC PD router requires tool parser registry".to_string())?;
let tool_parser_factory = ctx
.tool_parser_factory
.as_ref()
.ok_or_else(|| "gRPC PD router requires tool parser factory".to_string())?
.clone();
// Get prefill and decode workers from registry - they should have been created by WorkerManager
let prefill_workers = worker_registry.get_workers_filtered(
......@@ -86,7 +88,7 @@ impl GrpcPDRouter {
policy_registry,
tokenizer,
reasoning_parser_factory,
tool_parser_registry,
tool_parser_factory,
dp_aware: ctx.router_config.dp_aware,
api_key: ctx.router_config.api_key.clone(),
retry_config: ctx.router_config.effective_retry_config(),
......
......@@ -34,7 +34,7 @@ use crate::tokenizer::stop::{
};
use crate::tokenizer::traits::Tokenizer;
use crate::tokenizer::HuggingFaceTokenizer;
use crate::tool_parser::ParserRegistry;
use crate::tool_parser::ToolParserFactory;
use proto::generate_response::Response::{Chunk, Complete, Error};
use serde_json::{json, Map, Value};
use std::time::{Instant, SystemTime, UNIX_EPOCH};
......@@ -56,7 +56,7 @@ pub struct GrpcRouter {
policy_registry: Arc<PolicyRegistry>,
tokenizer: Arc<dyn Tokenizer>,
reasoning_parser_factory: ParserFactory,
tool_parser_registry: &'static ParserRegistry,
tool_parser_factory: ToolParserFactory,
dp_aware: bool,
api_key: Option<String>,
retry_config: RetryConfig,
......@@ -76,9 +76,11 @@ impl GrpcRouter {
.as_ref()
.ok_or_else(|| "gRPC router requires reasoning parser factory".to_string())?
.clone();
let tool_parser_registry = ctx
.tool_parser_registry
.ok_or_else(|| "gRPC router requires tool parser registry".to_string())?;
let tool_parser_factory = ctx
.tool_parser_factory
.as_ref()
.ok_or_else(|| "gRPC router requires tool parser factory".to_string())?
.clone();
let worker_registry = ctx.worker_registry.clone();
let policy_registry = ctx.policy_registry.clone();
......@@ -98,7 +100,7 @@ impl GrpcRouter {
policy_registry,
tokenizer,
reasoning_parser_factory,
tool_parser_registry,
tool_parser_factory,
dp_aware: ctx.router_config.dp_aware,
api_key: ctx.router_config.api_key.clone(),
retry_config: ctx.router_config.effective_retry_config(),
......@@ -779,15 +781,28 @@ impl GrpcRouter {
processed_text: &str,
model: &str,
) -> (Option<Vec<ToolCall>>, String) {
let Some(parser) = self.tool_parser_registry.get_parser(model) else {
return (None, processed_text.to_string());
// Get pooled parser for this model
let pooled_parser = self.tool_parser_factory.get_pooled(model);
// Check format detection first
let can_parse = {
let parser = pooled_parser.lock().await;
parser.detect_format(processed_text)
// Lock is dropped here
};
if !parser.detect_format(processed_text) {
if !can_parse {
return (None, processed_text.to_string());
}
match parser.parse_complete(processed_text).await {
// Lock again for async parsing
let result = {
let parser = pooled_parser.lock().await;
parser.parse_complete(processed_text).await
// Lock is dropped here
};
match result {
Ok((normal_text, parsed_tool_calls)) => {
if parsed_tool_calls.is_empty() {
return (None, normal_text);
......
......@@ -19,7 +19,7 @@ use crate::{
routers::{router_manager::RouterManager, RouterTrait},
service_discovery::{start_service_discovery, ServiceDiscoveryConfig},
tokenizer::{factory as tokenizer_factory, traits::Tokenizer},
tool_parser::ParserRegistry,
tool_parser::ToolParserFactory,
};
use axum::{
extract::{Path, Query, Request, State},
......@@ -46,7 +46,7 @@ pub struct AppContext {
pub rate_limiter: Arc<TokenBucket>,
pub tokenizer: Option<Arc<dyn Tokenizer>>,
pub reasoning_parser_factory: Option<ParserFactory>,
pub tool_parser_registry: Option<&'static ParserRegistry>,
pub tool_parser_factory: Option<ToolParserFactory>,
pub worker_registry: Arc<WorkerRegistry>,
pub policy_registry: Arc<PolicyRegistry>,
pub router_manager: Option<Arc<RouterManager>>,
......@@ -64,7 +64,7 @@ impl AppContext {
let rate_limit_tokens = rate_limit_tokens_per_second.unwrap_or(max_concurrent_requests);
let rate_limiter = Arc::new(TokenBucket::new(max_concurrent_requests, rate_limit_tokens));
let (tokenizer, reasoning_parser_factory, tool_parser_registry) =
let (tokenizer, reasoning_parser_factory, tool_parser_factory) =
if router_config.connection_mode == ConnectionMode::Grpc {
let tokenizer_path = router_config
.tokenizer_path
......@@ -80,9 +80,9 @@ impl AppContext {
.map_err(|e| format!("Failed to create tokenizer: {e}"))?,
);
let reasoning_parser_factory = Some(ParserFactory::new());
let tool_parser_registry = Some(ParserRegistry::new());
let tool_parser_factory = Some(ToolParserFactory::new());
(tokenizer, reasoning_parser_factory, tool_parser_registry)
(tokenizer, reasoning_parser_factory, tool_parser_factory)
} else {
(None, None, None)
};
......@@ -121,7 +121,7 @@ impl AppContext {
rate_limiter,
tokenizer,
reasoning_parser_factory,
tool_parser_registry,
tool_parser_factory,
worker_registry,
policy_registry,
router_manager,
......
......@@ -539,7 +539,7 @@ mod tests {
)),
tokenizer: None,
reasoning_parser_factory: None,
tool_parser_registry: None,
tool_parser_factory: None,
router_manager: None,
response_storage: Arc::new(crate::data_connector::MemoryResponseStorage::new()),
load_monitor: None,
......
// Factory and pool for creating model-specific tool parsers with pooling support.
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tokio::sync::Mutex;
use crate::tool_parser::parsers::{
DeepSeekParser, Glm4MoeParser, GptOssHarmonyParser, GptOssParser, JsonParser, KimiK2Parser,
LlamaParser, MistralParser, PythonicParser, QwenParser, Step3Parser,
};
use crate::tool_parser::traits::ToolParser;
/// Type alias for pooled parser instances.
pub type PooledToolParser = Arc<Mutex<Box<dyn ToolParser>>>;
/// Type alias for parser creator functions.
type ParserCreator = Arc<dyn Fn() -> Box<dyn ToolParser> + Send + Sync>;
/// Registry for model-specific tool parsers with pooling support.
///
/// All interior state is behind `Arc<RwLock<…>>`, so clones of the registry
/// share the same creators, pool, and mappings.
#[derive(Clone)]
pub struct ToolParserRegistry {
    /// Creator functions for parsers (used when pool is empty)
    creators: Arc<RwLock<HashMap<String, ParserCreator>>>,
    /// Pooled parser instances for reuse; keyed by parser name
    pool: Arc<RwLock<HashMap<String, PooledToolParser>>>,
    /// Model pattern to parser name mappings (patterns ending in `*` are
    /// treated as prefixes during lookup)
    model_mapping: Arc<RwLock<HashMap<String, String>>>,
    /// Default parser name, used when no mapping matches
    default_parser: Arc<RwLock<String>>,
}
impl ToolParserRegistry {
    /// Create a new empty registry with "json" as the default parser name.
    pub fn new() -> Self {
        Self {
            creators: Arc::new(RwLock::new(HashMap::new())),
            pool: Arc::new(RwLock::new(HashMap::new())),
            model_mapping: Arc::new(RwLock::new(HashMap::new())),
            default_parser: Arc::new(RwLock::new("json".to_string())),
        }
    }

    /// Register a parser creator for a given parser type.
    ///
    /// The creator is invoked lazily the first time a pooled instance of
    /// `name` is requested.
    pub fn register_parser<F>(&self, name: &str, creator: F)
    where
        F: Fn() -> Box<dyn ToolParser> + Send + Sync + 'static,
    {
        let mut creators = self.creators.write().unwrap();
        creators.insert(name.to_string(), Arc::new(creator));
    }

    /// Map a model name/pattern to a parser.
    ///
    /// Patterns ending in `*` match any model id with that prefix.
    pub fn map_model(&self, model: impl Into<String>, parser: impl Into<String>) {
        let mut mapping = self.model_mapping.write().unwrap();
        mapping.insert(model.into(), parser.into());
    }

    /// Get a pooled parser by exact name.
    /// Returns a shared parser instance from the pool, creating one if needed.
    pub fn get_pooled_parser(&self, name: &str) -> Option<PooledToolParser> {
        // Fast path: an instance is already pooled.
        {
            let pool = self.pool.read().unwrap();
            if let Some(parser) = pool.get(name) {
                return Some(Arc::clone(parser));
            }
        }

        // Slow path: create and insert under the write lock. The entry is
        // re-checked while holding the write lock so that two threads racing
        // past the read-lock miss above still end up sharing one instance.
        // (The previous unconditional insert could hand one caller an
        // instance that was immediately replaced in the pool by the other.)
        let creator = Arc::clone(self.creators.read().unwrap().get(name)?);
        let mut pool = self.pool.write().unwrap();
        let parser = pool
            .entry(name.to_string())
            .or_insert_with(|| Arc::new(Mutex::new(creator())));
        Some(Arc::clone(parser))
    }

    /// Get parser for a specific model.
    ///
    /// Resolution order: exact mapping match, then the longest matching
    /// `prefix*` pattern, then the default parser.
    pub fn get_pooled_for_model(&self, model: &str) -> Option<PooledToolParser> {
        // Try exact match first
        {
            let mapping = self.model_mapping.read().unwrap();
            if let Some(parser_name) = mapping.get(model) {
                if let Some(parser) = self.get_pooled_parser(parser_name) {
                    return Some(parser);
                }
            }
        }

        // Try prefix matching, preferring the most specific (longest)
        // pattern. The parser name is cloned so the mapping read lock is
        // released before get_pooled_parser takes the pool locks.
        let best_match = {
            let model_mapping = self.model_mapping.read().unwrap();
            model_mapping
                .iter()
                .filter(|(pattern, _)| {
                    pattern.ends_with('*') && model.starts_with(&pattern[..pattern.len() - 1])
                })
                .max_by_key(|(pattern, _)| pattern.len())
                .map(|(_, parser_name)| parser_name.clone())
        };
        if let Some(parser_name) = best_match {
            if let Some(parser) = self.get_pooled_parser(&parser_name) {
                return Some(parser);
            }
        }

        // Fall back to default parser
        let default = self.default_parser.read().unwrap().clone();
        self.get_pooled_parser(&default)
    }

    /// Clear the parser pool, forcing new instances to be created.
    pub fn clear_pool(&self) {
        let mut pool = self.pool.write().unwrap();
        pool.clear();
    }

    /// Set the default parser used when no mapping matches.
    pub fn set_default_parser(&self, name: impl Into<String>) {
        let mut default = self.default_parser.write().unwrap();
        *default = name.into();
    }
}
impl Default for ToolParserRegistry {
fn default() -> Self {
Self::new()
}
}
/// Factory for creating tool parsers based on model type.
///
/// Thin wrapper around a [`ToolParserRegistry`] pre-populated with the
/// default parsers and model-name mappings; clones share the same registry.
#[derive(Clone)]
pub struct ToolParserFactory {
    // Shared registry holding creators, the parser pool, and model mappings
    registry: ToolParserRegistry,
}
impl ToolParserFactory {
    /// Create a new factory with default parsers registered.
    pub fn new() -> Self {
        let registry = ToolParserRegistry::new();

        // Register default parsers
        registry.register_parser("json", || Box::new(JsonParser::new()));
        registry.register_parser("mistral", || Box::new(MistralParser::new()));
        registry.register_parser("qwen", || Box::new(QwenParser::new()));
        registry.register_parser("pythonic", || Box::new(PythonicParser::new()));
        registry.register_parser("llama", || Box::new(LlamaParser::new()));
        registry.register_parser("deepseek", || Box::new(DeepSeekParser::new()));
        registry.register_parser("glm4_moe", || Box::new(Glm4MoeParser::new()));
        registry.register_parser("step3", || Box::new(Step3Parser::new()));
        registry.register_parser("kimik2", || Box::new(KimiK2Parser::new()));

        // Register GPT-OSS parsers; both variants remain addressable by
        // their explicit names regardless of which one backs "gpt_oss".
        registry.register_parser("gpt_oss_legacy", || Box::new(GptOssParser::new()));
        registry.register_parser("gpt_oss_harmony", || Box::new(GptOssHarmonyParser::new()));

        // Choose which GPT-OSS variant to use as default (env-controlled,
        // decided once at factory construction time)
        if use_harmony_gpt_oss() {
            registry.register_parser("gpt_oss", || Box::new(GptOssHarmonyParser::new()));
        } else {
            registry.register_parser("gpt_oss", || Box::new(GptOssParser::new()));
        }

        // Register default model mappings
        Self::register_default_mappings(&registry);

        Self { registry }
    }

    /// Register the built-in model-name → parser mappings.
    fn register_default_mappings(registry: &ToolParserRegistry) {
        // OpenAI models
        registry.map_model("gpt-4*", "json");
        registry.map_model("gpt-3.5*", "json");
        registry.map_model("gpt-4o*", "json");

        // Anthropic models
        registry.map_model("claude-*", "json");

        // Mistral models
        registry.map_model("mistral-*", "mistral");
        registry.map_model("mixtral-*", "mistral");

        // Qwen models
        registry.map_model("qwen*", "qwen");
        registry.map_model("Qwen*", "qwen");

        // Llama models
        registry.map_model("llama-4*", "pythonic");
        registry.map_model("meta-llama-4*", "pythonic");
        registry.map_model("llama-3.2*", "llama");
        registry.map_model("meta-llama-3.2*", "llama");
        registry.map_model("llama-*", "json");
        registry.map_model("meta-llama-*", "json");

        // DeepSeek models
        registry.map_model("deepseek-v3*", "deepseek");
        registry.map_model("deepseek-ai/DeepSeek-V3*", "deepseek");
        registry.map_model("deepseek-*", "pythonic");

        // GLM models
        registry.map_model("glm-4.5*", "glm4_moe");
        registry.map_model("glm-4.6*", "glm4_moe");
        registry.map_model("glm-*", "json");

        // Step3 models
        registry.map_model("step3*", "step3");
        registry.map_model("Step-3*", "step3");

        // Kimi models
        registry.map_model("kimi-k2*", "kimik2");
        registry.map_model("Kimi-K2*", "kimik2");
        registry.map_model("moonshot*/Kimi-K2*", "kimik2");

        // GPT-OSS models
        registry.map_model("gpt-oss*", "gpt_oss");
        registry.map_model("t4-*", "gpt_oss");

        // Other models
        registry.map_model("gemini-*", "json");
        registry.map_model("palm-*", "json");
        registry.map_model("gemma-*", "json");
    }

    /// Resolve the parser name for a model id.
    ///
    /// Mirrors the registry's lookup order: exact mapping match first, then
    /// the longest matching `prefix*` pattern, then the default parser name.
    fn resolve_parser_name(&self, model_id: &str) -> String {
        let mapping = self.registry.model_mapping.read().unwrap();

        // Try exact match first
        if let Some(parser_name) = mapping.get(model_id) {
            return parser_name.clone();
        }

        // Try prefix matching (most specific pattern wins), else default
        mapping
            .iter()
            .filter(|(pattern, _)| {
                pattern.ends_with('*') && model_id.starts_with(&pattern[..pattern.len() - 1])
            })
            .max_by_key(|(pattern, _)| pattern.len())
            .map(|(_, parser_name)| parser_name.clone())
            .unwrap_or_else(|| self.registry.default_parser.read().unwrap().clone())
    }

    /// Get a pooled parser for the given model ID.
    /// Returns a shared instance that can be used concurrently.
    /// Falls back to JSON parser if model is not recognized.
    pub fn get_pooled(&self, model_id: &str) -> PooledToolParser {
        self.registry
            .get_pooled_for_model(model_id)
            .unwrap_or_else(|| {
                // Fallback to JSON parser
                self.registry
                    .get_pooled_parser("json")
                    .expect("JSON parser should always be registered")
            })
    }

    /// Get the internal registry for custom registration.
    pub fn registry(&self) -> &ToolParserRegistry {
        &self.registry
    }

    /// Clear the parser pool.
    pub fn clear_pool(&self) {
        self.registry.clear_pool();
    }

    /// Get a non-pooled parser for the given model ID (creates a fresh instance each time).
    /// This is useful for benchmarks and testing where you want independent parser instances.
    pub fn get_parser(&self, model_id: &str) -> Option<Arc<dyn ToolParser>> {
        // Reuse the shared resolution logic rather than duplicating the
        // exact/prefix/default matching inline.
        let parser_type = self.resolve_parser_name(model_id);
        let creators = self.registry.creators.read().unwrap();
        creators.get(&parser_type).map(|creator| {
            // Call the creator to get a Box<dyn ToolParser>, then convert to Arc
            let boxed_parser = creator();
            Arc::from(boxed_parser)
        })
    }

    /// List all registered parsers (for compatibility with old API).
    pub fn list_parsers(&self) -> Vec<String> {
        self.registry
            .creators
            .read()
            .unwrap()
            .keys()
            .cloned()
            .collect()
    }
}
impl Default for ToolParserFactory {
fn default() -> Self {
Self::new()
}
}
/// Whether the Harmony variant should back the generic "gpt_oss" parser.
///
/// Controlled by the `ROUTER_USE_HARMONY_GPT_OSS` environment variable: any
/// of `1`, `true`, `yes`, `on` — case-insensitive, surrounding whitespace
/// ignored — enables it. An unset/unreadable variable or any other value
/// disables it.
fn use_harmony_gpt_oss() -> bool {
    std::env::var("ROUTER_USE_HARMONY_GPT_OSS")
        .map(|value| {
            let normalized = value.trim();
            // Case-insensitive comparison generalizes the previous explicit
            // enumeration ("1" | "true" | "TRUE" | "True" | ...): every value
            // accepted before is still accepted, plus mixed-case spellings
            // such as "tRue" or "YeS".
            ["1", "true", "yes", "on"]
                .iter()
                .any(|accepted| normalized.eq_ignore_ascii_case(accepted))
        })
        .unwrap_or(false)
}
......@@ -3,8 +3,8 @@
/// This module provides infrastructure for parsing tool calls from various model formats.
// Core modules
pub mod errors;
pub mod factory;
pub mod partial_json;
pub mod registry;
pub mod state;
pub mod traits;
pub mod types;
......@@ -17,10 +17,9 @@ mod tests;
// Re-export commonly used types
pub use errors::{ToolParserError, ToolParserResult};
pub use registry::ParserRegistry;
pub use state::{ParsePhase, ParseState};
pub use factory::{PooledToolParser, ToolParserFactory, ToolParserRegistry};
pub use traits::{PartialJsonParser, ToolParser};
pub use types::{FunctionCall, PartialToolCall, StreamResult, ToolCall};
pub use types::{FunctionCall, PartialToolCall, StreamingParseResult, ToolCall};
// Re-export parsers for convenience
pub use parsers::{
......
......@@ -2,12 +2,13 @@ use async_trait::async_trait;
use regex::Regex;
use serde_json::Value;
use crate::protocols::spec::Tool;
use crate::tool_parser::{
errors::{ToolParserError, ToolParserResult},
partial_json::PartialJson,
state::ParseState,
parsers::helpers,
traits::ToolParser,
types::{FunctionCall, StreamResult, ToolCall},
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
};
/// DeepSeek V3 format parser for tool calls
......@@ -20,12 +21,29 @@ use crate::tool_parser::{
/// - JSON arguments in code blocks
/// - Support for multiple sequential tool calls
pub struct DeepSeekParser {
/// Parser for handling incomplete JSON during streaming
partial_json: PartialJson,
/// Regex for extracting complete tool calls
tool_call_extractor: Regex,
/// Regex for extracting function details
func_detail_extractor: Regex,
/// Regex for matching partial tool calls during streaming
partial_tool_call_regex: Regex,
/// Regex pattern for removing completed tool calls from buffer
tool_call_end_pattern: Regex,
/// Buffer for accumulating incomplete patterns across chunks
buffer: String,
/// Stores complete tool call info (name and arguments) for each tool being parsed
prev_tool_call_arr: Vec<Value>,
/// Index of currently streaming tool call (-1 means no active tool)
current_tool_id: i32,
/// Flag for whether current tool's name has been sent to client
current_tool_name_sent: bool,
/// Tracks raw JSON string content streamed to client for each tool's arguments
streamed_args_for_tool: Vec<String>,
}
impl DeepSeekParser {
......@@ -38,10 +56,24 @@ impl DeepSeekParser {
let func_detail_pattern = r"(?s)<|tool▁call▁begin|>(.*?)<|tool▁sep|>(.*?)\n```json\n(.*?)\n```<|tool▁call▁end|>";
let func_detail_extractor = Regex::new(func_detail_pattern).expect("Valid regex pattern");
// Partial pattern for streaming - uses .* (greedy) not .*? to match all partial content
let partial_pattern = r"(?s)<|tool▁call▁begin|>(.*)<|tool▁sep|>(.*)\n```json\n(.*)";
let partial_tool_call_regex = Regex::new(partial_pattern).expect("Valid regex pattern");
// Pattern for removing completed tool calls
let end_pattern = r"(?s)<|tool▁call▁begin|>.*?<|tool▁call▁end|>";
let tool_call_end_pattern = Regex::new(end_pattern).expect("Valid regex pattern");
Self {
partial_json: PartialJson::default(),
tool_call_extractor,
func_detail_extractor,
partial_tool_call_regex,
tool_call_end_pattern,
buffer: String::new(),
prev_tool_call_arr: Vec::new(),
current_tool_id: -1,
current_tool_name_sent: false,
streamed_args_for_tool: Vec::new(),
}
}
......@@ -143,107 +175,146 @@ impl ToolParser for DeepSeekParser {
}
async fn parse_incremental(
&self,
&mut self,
chunk: &str,
state: &mut ParseState,
) -> ToolParserResult<StreamResult> {
state.buffer.push_str(chunk);
tools: &[Tool],
) -> ToolParserResult<StreamingParseResult> {
self.buffer.push_str(chunk);
let current_text = &self.buffer.clone();
// Check if we have a tool call (either the start token or individual tool call)
let has_tool_call =
self.has_tool_markers(current_text) || current_text.contains("<|tool▁call▁begin|>");
// Check for tool markers
if !self.has_tool_markers(&state.buffer) {
if !has_tool_call {
// No tool markers detected - return all buffered content as normal text
let normal_text = std::mem::take(&mut state.buffer);
return Ok(StreamResult::NormalText(normal_text));
// Strip out end tokens if present
let mut normal_text = std::mem::take(&mut self.buffer);
for e_token in ["<|tool▁calls▁end|>", "```", "<|tool▁call▁end|>"] {
normal_text = normal_text.replace(e_token, "");
}
return Ok(StreamingParseResult {
normal_text,
calls: vec![],
});
}
// Check for text before tool markers and extract it as normal text
if let Some(marker_pos) = state.buffer.find("<|tool▁calls▁begin|>") {
if marker_pos > 0 {
// We have text before the tool marker - extract it as normal text
let normal_text: String = state.buffer.drain(..marker_pos).collect();
return Ok(StreamResult::NormalText(normal_text));
// Build tool indices for validation
let tool_indices = helpers::get_tool_indices(tools);
let mut calls: Vec<ToolCallItem> = Vec::new();
// Try to match the partial tool call pattern
if let Some(captures) = self.partial_tool_call_regex.captures(current_text) {
let func_name = captures.get(2).map_or("", |m| m.as_str()).trim();
let func_args_raw = captures.get(3).map_or("", |m| m.as_str()).trim();
// Validate tool name
if !tool_indices.contains_key(func_name) {
// Invalid tool name - skip this tool, preserve indexing for next tool
tracing::warn!("Invalid tool name '{}' - skipping", func_name);
helpers::reset_current_tool_state(
&mut self.buffer,
&mut self.current_tool_name_sent,
&mut self.streamed_args_for_tool,
&self.prev_tool_call_arr,
);
return Ok(StreamingParseResult::default());
}
}
// Look for start of tool calls
if let Some(start_pos) = state.buffer.find("<|tool▁calls▁begin|>") {
// Look for individual tool call start
let search_from = start_pos + "<|tool▁calls▁begin|>".len();
if let Some(call_start) = state.buffer[search_from..].find("<|tool▁call▁begin|>")
{
let call_start_abs = search_from + call_start;
// Look for the end of this tool call
let search_end_from = call_start_abs + "<|tool▁call▁begin|>".len();
if let Some(call_end) = state.buffer[search_end_from..].find("<|tool▁call▁end|>")
{
let call_end_abs = search_end_from + call_end + "<|tool▁call▁end|>".len();
// Extract and parse the complete tool call
let tool_call_text = &state.buffer[call_start_abs..call_end_abs];
match self.parse_tool_call(tool_call_text) {
Ok(tool) => {
// Remove the processed part from buffer
state.buffer.drain(..call_end_abs);
return Ok(StreamResult::ToolComplete(tool));
}
Err(_) => {
// Parsing failed, skip this tool call
state.buffer.drain(..call_end_abs);
}
// Initialize state if this is the first tool call
if self.current_tool_id == -1 {
self.current_tool_id = 0;
self.prev_tool_call_arr = Vec::new();
self.streamed_args_for_tool = vec![String::new()];
}
// Ensure we have enough entries in our tracking arrays
helpers::ensure_capacity(
self.current_tool_id,
&mut self.prev_tool_call_arr,
&mut self.streamed_args_for_tool,
);
// Send tool name if not sent yet
if !self.current_tool_name_sent {
calls.push(ToolCallItem {
tool_index: self.current_tool_id as usize,
name: Some(func_name.to_string()),
parameters: String::new(),
});
self.current_tool_name_sent = true;
// Store the tool call info for serving layer completions endpoint
let tool_id = self.current_tool_id as usize;
if self.prev_tool_call_arr.len() <= tool_id {
self.prev_tool_call_arr
.resize_with(tool_id + 1, || Value::Null);
}
self.prev_tool_call_arr[tool_id] = serde_json::json!({
"name": func_name,
"arguments": {},
});
} else {
// Compute incremental diff
let tool_id = self.current_tool_id as usize;
let last_sent = self
.streamed_args_for_tool
.get(tool_id)
.map(|s| s.as_str())
.unwrap_or("");
let argument_diff = func_args_raw
.strip_prefix(last_sent)
.unwrap_or(func_args_raw);
if !argument_diff.is_empty() {
calls.push(ToolCallItem {
tool_index: tool_id,
name: None,
parameters: argument_diff.to_string(),
});
if tool_id < self.streamed_args_for_tool.len() {
self.streamed_args_for_tool[tool_id].push_str(argument_diff);
}
} else {
// Tool call not complete yet, try to extract partial info
let partial = &state.buffer[search_end_from..];
// Try to extract function name
if let Some(sep_pos) = partial.find("<|tool▁sep|>") {
if let Some(_func_start) = partial[..sep_pos].rfind("function") {
// We have the function type marker
let after_sep = &partial[sep_pos + "<|tool▁sep|>".len()..];
// Look for function name (ends at newline before ```json)
if let Some(name_end) = after_sep.find("\n```json\n") {
let func_name = after_sep[..name_end].trim();
if !state.in_string {
state.in_string = true; // Mark name as sent
return Ok(StreamResult::ToolName {
index: 0,
name: func_name.to_string(),
});
}
// Try to extract partial arguments
let args_start = name_end + "\n```json\n".len();
let partial_args = &after_sep[args_start..];
// Check if we can parse partial JSON
if !partial_args.is_empty() {
match self.partial_json.parse_value(partial_args) {
Ok((value, _consumed)) => {
let args_str = serde_json::to_string(&value)
.unwrap_or_else(|_| "{}".to_string());
return Ok(StreamResult::ToolArguments {
index: 0,
arguments: args_str,
});
}
Err(_) => {
// Can't parse yet, continue waiting for more data
}
}
}
}
// Check if JSON is complete
if helpers::is_complete_json(func_args_raw) {
// Update the stored arguments
if let Ok(parsed_args) = serde_json::from_str::<Value>(func_args_raw) {
let tool_id = self.current_tool_id as usize;
if tool_id < self.prev_tool_call_arr.len() {
if let Some(obj) = self.prev_tool_call_arr[tool_id].as_object_mut() {
obj.insert("arguments".to_string(), parsed_args);
}
}
}
// Find the end of the current tool call and remove only that part from buffer
if let Some(mat) = self.tool_call_end_pattern.find(current_text) {
// Remove the completed tool call from buffer, keep any remaining content
self.buffer = current_text[mat.end()..].to_string();
} else {
self.buffer.clear();
}
let result = StreamingParseResult {
normal_text: String::new(),
calls,
};
self.current_tool_id += 1;
self.current_tool_name_sent = false;
return Ok(result);
}
}
}
Ok(StreamResult::Incomplete)
Ok(StreamingParseResult {
normal_text: String::new(),
calls,
})
}
fn detect_format(&self, text: &str) -> bool {
......
......@@ -2,11 +2,13 @@ use async_trait::async_trait;
use regex::Regex;
use serde_json::Value;
use crate::protocols::spec::Tool;
use crate::tool_parser::{
errors::{ToolParserError, ToolParserResult},
state::ParseState,
parsers::helpers,
traits::ToolParser,
types::{FunctionCall, StreamResult, ToolCall},
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
};
/// GLM-4 MoE format parser for tool calls
......@@ -25,6 +27,22 @@ pub struct Glm4MoeParser {
func_detail_extractor: Regex,
/// Regex for extracting argument key-value pairs
arg_extractor: Regex,
/// Buffer for accumulating incomplete patterns across chunks
buffer: String,
/// Stores complete tool call info (name and arguments) for each tool being parsed
prev_tool_call_arr: Vec<Value>,
/// Index of currently streaming tool call (-1 means no active tool)
current_tool_id: i32,
/// Tracks raw JSON string content streamed to client for each tool's arguments
streamed_args_for_tool: Vec<String>,
/// Token configuration
bot_token: &'static str,
eot_token: &'static str,
}
impl Glm4MoeParser {
......@@ -44,12 +62,18 @@ impl Glm4MoeParser {
tool_call_extractor,
func_detail_extractor,
arg_extractor,
buffer: String::new(),
prev_tool_call_arr: Vec::new(),
current_tool_id: -1,
streamed_args_for_tool: Vec::new(),
bot_token: "<tool_call>",
eot_token: "</tool_call>",
}
}
/// Check if text contains GLM-4 MoE tool markers
fn has_tool_markers(&self, text: &str) -> bool {
text.contains("<tool_call>")
text.contains(self.bot_token)
}
/// Parse arguments from key-value pairs
......@@ -120,6 +144,25 @@ impl Glm4MoeParser {
Ok(None)
}
}
/// Parse and return StreamingParseResult (mirrors Python's detect_and_parse)
/// Parse all tool calls from text (shared logic for complete and incremental parsing)
fn parse_tool_calls_from_text(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> {
let mut tools = Vec::new();
for mat in self.tool_call_extractor.find_iter(text) {
match self.parse_tool_call(mat.as_str()) {
Ok(Some(tool)) => tools.push(tool),
Ok(None) => continue,
Err(e) => {
tracing::warn!("Failed to parse tool call: {}", e);
continue;
}
}
}
Ok(tools)
}
}
impl Default for Glm4MoeParser {
......@@ -140,18 +183,8 @@ impl ToolParser for Glm4MoeParser {
let idx = text.find("<tool_call>").unwrap();
let normal_text = text[..idx].to_string();
// Extract tool calls
let mut tools = Vec::new();
for mat in self.tool_call_extractor.find_iter(text) {
match self.parse_tool_call(mat.as_str()) {
Ok(Some(tool)) => tools.push(tool),
Ok(None) => continue,
Err(e) => {
tracing::warn!("Failed to parse tool call: {}", e);
continue;
}
}
}
// Parse all tool calls using shared helper
let tools = self.parse_tool_calls_from_text(text)?;
// If no tools were successfully parsed despite having markers, return entire text as fallback
if tools.is_empty() {
......@@ -162,78 +195,127 @@ impl ToolParser for Glm4MoeParser {
}
async fn parse_incremental(
&self,
&mut self,
chunk: &str,
state: &mut ParseState,
) -> ToolParserResult<StreamResult> {
state.buffer.push_str(chunk);
// Check for tool markers
if !self.has_tool_markers(&state.buffer) {
// No tool markers detected - return all buffered content as normal text
let normal_text = std::mem::take(&mut state.buffer);
return Ok(StreamResult::NormalText(normal_text));
tools: &[Tool],
) -> ToolParserResult<StreamingParseResult> {
// Python logic: Wait for complete tool call, then parse it all at once
self.buffer.push_str(chunk);
let current_text = &self.buffer.clone();
// Check if we have bot_token
let start = current_text.find(self.bot_token);
if start.is_none() {
self.buffer.clear();
// If we're in the middle of streaming (current_tool_id > 0), don't return text
let normal_text = if self.current_tool_id > 0 {
String::new()
} else {
current_text.clone()
};
return Ok(StreamingParseResult {
normal_text,
calls: vec![],
});
}
// Check for text before tool markers and extract it as normal text
if let Some(marker_pos) = state.buffer.find("<tool_call>") {
if marker_pos > 0 {
// We have text before the tool marker - extract it as normal text
let normal_text: String = state.buffer.drain(..marker_pos).collect();
return Ok(StreamResult::NormalText(normal_text));
// Check if we have eot_token (end of tool call)
let end = current_text.find(self.eot_token);
if let Some(end_pos) = end {
// We have a complete tool call!
// Initialize state if this is the first tool call
if self.current_tool_id == -1 {
self.current_tool_id = 0;
self.prev_tool_call_arr = Vec::new();
self.streamed_args_for_tool = vec![String::new()];
}
}
// Look for start of tool call
if let Some(start_pos) = state.buffer.find("<tool_call>") {
// Look for the end of this tool call
let search_from = start_pos + "<tool_call>".len();
if let Some(end_pos) = state.buffer[search_from..].find("</tool_call>") {
let end_abs = search_from + end_pos + "</tool_call>".len();
// Ensure we have enough entries in our tracking arrays
helpers::ensure_capacity(
self.current_tool_id,
&mut self.prev_tool_call_arr,
&mut self.streamed_args_for_tool,
);
// Parse the complete block using shared helper
let block_end = end_pos + self.eot_token.len();
let parsed_tools = self.parse_tool_calls_from_text(&current_text[..block_end])?;
// Extract normal text before tool calls
let idx = current_text.find(self.bot_token);
let normal_text = if let Some(pos) = idx {
current_text[..pos].trim().to_string()
} else {
String::new()
};
// Extract and parse the complete tool call
let tool_call_text = &state.buffer[start_pos..end_abs];
// Build tool indices for validation
let tool_indices = helpers::get_tool_indices(tools);
let mut calls = Vec::new();
if !parsed_tools.is_empty() {
// Take the first tool and convert to ToolCallItem
let tool_call = &parsed_tools[0];
let tool_id = self.current_tool_id as usize;
// Validate tool name
if !tool_indices.contains_key(&tool_call.function.name) {
// Invalid tool name - skip this tool, preserve indexing for next tool
tracing::warn!("Invalid tool name '{}' - skipping", tool_call.function.name);
helpers::reset_current_tool_state(
&mut self.buffer,
&mut false, // glm4_moe doesn't track name_sent per tool
&mut self.streamed_args_for_tool,
&self.prev_tool_call_arr,
);
return Ok(StreamingParseResult::default());
}
if let Some(tool) = self.parse_tool_call(tool_call_text)? {
// Remove the processed part from buffer
state.buffer.drain(..end_abs);
calls.push(ToolCallItem {
tool_index: tool_id,
name: Some(tool_call.function.name.clone()),
parameters: tool_call.function.arguments.clone(),
});
return Ok(StreamResult::ToolComplete(tool));
// Store in tracking arrays
if self.prev_tool_call_arr.len() <= tool_id {
self.prev_tool_call_arr
.resize_with(tool_id + 1, || Value::Null);
}
} else {
// Tool call not complete yet, try to extract partial info
let partial = &state.buffer[search_from..];
// Try to extract function name (first line after <tool_call>)
if let Some(name_end) = partial.find('\n') {
let func_name = partial[..name_end].trim();
if !func_name.is_empty() && !state.in_string {
state.in_string = true; // Mark name as sent
return Ok(StreamResult::ToolName {
index: 0,
name: func_name.to_string(),
});
}
// Try to extract partial arguments
let args_text = &partial[name_end + 1..];
let partial_args = self.parse_arguments(args_text)?;
if !partial_args.is_empty() {
let args_str = serde_json::to_string(&partial_args)
.unwrap_or_else(|_| "{}".to_string());
// Parse parameters as JSON and store
if let Ok(args) = serde_json::from_str::<Value>(&tool_call.function.arguments) {
self.prev_tool_call_arr[tool_id] = serde_json::json!({
"name": tool_call.function.name,
"arguments": args,
});
}
return Ok(StreamResult::ToolArguments {
index: 0,
arguments: args_str,
});
}
if self.streamed_args_for_tool.len() <= tool_id {
self.streamed_args_for_tool
.resize_with(tool_id + 1, String::new);
}
self.streamed_args_for_tool[tool_id] = tool_call.function.arguments.clone();
self.current_tool_id += 1;
}
// Remove processed portion from buffer
self.buffer = current_text[block_end..].to_string();
return Ok(StreamingParseResult { normal_text, calls });
}
Ok(StreamResult::Incomplete)
// No complete tool call yet - return normal text before start token
let start_pos = start.unwrap();
let normal_text = current_text[..start_pos].to_string();
self.buffer = current_text[start_pos..].to_string();
Ok(StreamingParseResult {
normal_text,
calls: vec![],
})
}
fn detect_format(&self, text: &str) -> bool {
......
use async_trait::async_trait;
use crate::protocols::spec::Tool;
use crate::tool_parser::{
errors::ToolParserResult,
state::ParseState,
traits::{TokenToolParser, ToolParser},
types::{StreamResult, ToolCall},
types::{StreamingParseResult, ToolCall},
};
/// Placeholder for the Harmony-backed GPT-OSS parser.
......@@ -29,12 +30,12 @@ impl ToolParser for GptOssHarmonyParser {
}
async fn parse_incremental(
&self,
&mut self,
_chunk: &str,
_state: &mut ParseState,
) -> ToolParserResult<StreamResult> {
_tools: &[Tool],
) -> ToolParserResult<StreamingParseResult> {
// Temporary stub until the Harmony streaming pipeline is implemented.
Ok(StreamResult::Incomplete)
Ok(StreamingParseResult::default())
}
fn detect_format(&self, text: &str) -> bool {
......@@ -61,10 +62,10 @@ impl TokenToolParser for GptOssHarmonyParser {
}
async fn parse_incremental_tokens(
&self,
&mut self,
_tokens: &[u32],
_state: &mut ParseState,
) -> ToolParserResult<StreamResult> {
Ok(StreamResult::Incomplete)
_tools: &[Tool],
) -> ToolParserResult<StreamingParseResult> {
Ok(StreamingParseResult::default())
}
}
......@@ -2,12 +2,14 @@ use async_trait::async_trait;
use regex::Regex;
use serde_json::Value;
use crate::protocols::spec::Tool;
use crate::tool_parser::{
errors::{ToolParserError, ToolParserResult},
parsers::helpers,
partial_json::PartialJson,
state::ParseState,
traits::ToolParser,
types::{FunctionCall, StreamResult, ToolCall},
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
};
/// GPT-OSS format parser for tool calls
......@@ -26,6 +28,11 @@ pub struct GptOssParser {
function_call_extractor: Regex,
/// Regex for extracting streaming function calls
streaming_extractor: Regex,
/// Buffer for accumulating chunks
buffer: String,
/// Whether the tool name has been sent (for streaming)
name_sent: bool,
}
impl GptOssParser {
......@@ -45,6 +52,9 @@ impl GptOssParser {
partial_json: PartialJson::default(),
function_call_extractor,
streaming_extractor,
buffer: String::new(),
name_sent: false,
}
}
......@@ -123,21 +133,21 @@ impl ToolParser for GptOssParser {
}
async fn parse_incremental(
&self,
&mut self,
chunk: &str,
state: &mut ParseState,
) -> ToolParserResult<StreamResult> {
state.buffer.push_str(chunk);
tools: &[Tool],
) -> ToolParserResult<StreamingParseResult> {
self.buffer.push_str(chunk);
// Check for tool markers
if !self.has_tool_markers(&state.buffer) {
if !self.has_tool_markers(&self.buffer) {
// No markers found, clear buffer and return
state.buffer.clear();
return Ok(StreamResult::Incomplete);
self.buffer.clear();
return Ok(StreamingParseResult::default());
}
// Try to match streaming pattern
if let Some(captures) = self.streaming_extractor.captures(&state.buffer) {
if let Some(captures) = self.streaming_extractor.captures(&self.buffer) {
if let (Some(name_match), Some(args_match)) = (captures.get(1), captures.get(2)) {
let full_function_name = name_match.as_str();
let partial_args = args_match.as_str();
......@@ -146,16 +156,30 @@ impl ToolParser for GptOssParser {
let function_name = self.extract_function_name(full_function_name);
// Send function name if not sent yet
if !state.in_string {
state.in_string = true; // Mark name as sent
return Ok(StreamResult::ToolName {
index: 0,
name: function_name.clone(),
if !self.name_sent {
// Validate tool name
let tool_indices = helpers::get_tool_indices(tools);
if !tool_indices.contains_key(&function_name) {
// Invalid tool name - skip
tracing::warn!("Invalid tool name '{}' - skipping", function_name);
self.buffer.clear();
self.name_sent = false;
return Ok(StreamingParseResult::default());
}
self.name_sent = true; // Mark name as sent
return Ok(StreamingParseResult {
normal_text: String::new(),
calls: vec![ToolCallItem {
tool_index: 0,
name: Some(function_name.clone()),
parameters: String::new(),
}],
});
}
// Check if we have a complete function call
if let Some(complete_match) = self.function_call_extractor.captures(&state.buffer) {
if let Some(complete_match) = self.function_call_extractor.captures(&self.buffer) {
if let Some(args_match) = complete_match.get(2) {
let args_content = args_match.as_str().trim();
......@@ -170,26 +194,22 @@ impl ToolParser for GptOssParser {
}
};
// Generate unique ID
let id = format!("gpt_oss_call_{}", uuid::Uuid::new_v4());
let tool = ToolCall {
id,
r#type: "function".to_string(),
function: FunctionCall {
name: function_name,
arguments,
},
};
// Remove the processed part from buffer
let complete_end = complete_match.get(0).unwrap().end();
state.buffer.drain(..complete_end);
self.buffer.drain(..complete_end);
// Reset state for next tool
state.in_string = false;
return Ok(StreamResult::ToolComplete(tool));
self.name_sent = false;
// Return final arguments
return Ok(StreamingParseResult {
normal_text: String::new(),
calls: vec![ToolCallItem {
tool_index: 0,
name: None,
parameters: arguments,
}],
});
}
} else {
// Try to parse partial JSON for streaming arguments
......@@ -206,9 +226,13 @@ impl ToolParser for GptOssParser {
let args_str = serde_json::to_string(&value)
.unwrap_or_else(|_| "{}".to_string());
return Ok(StreamResult::ToolArguments {
index: 0,
arguments: args_str,
return Ok(StreamingParseResult {
normal_text: String::new(),
calls: vec![ToolCallItem {
tool_index: 0,
name: None,
parameters: args_str,
}],
});
}
Err(_) => {
......@@ -220,7 +244,7 @@ impl ToolParser for GptOssParser {
}
}
Ok(StreamResult::Incomplete)
Ok(StreamingParseResult::default())
}
fn detect_format(&self, text: &str) -> bool {
......
use crate::protocols::spec::Tool;
use serde_json::Value;
use std::collections::HashMap;
use crate::tool_parser::errors::{ToolParserError, ToolParserResult};
use crate::tool_parser::types::{StreamingParseResult, ToolCallItem};
/// Build a lookup table from tool function name to its position in `tools`.
///
/// If two tools share a name, the later index wins (same as the previous
/// collect-into-map behavior).
pub fn get_tool_indices(tools: &[Tool]) -> HashMap<String, usize> {
    let mut indices = HashMap::with_capacity(tools.len());
    for (idx, tool) in tools.iter().enumerate() {
        indices.insert(tool.function.name.clone(), idx);
    }
    indices
}
/// Check if a buffer ends with a partial occurrence of a token.
///
/// Returns `Some(len)` where `len` is the byte length of the shortest strict
/// prefix of `token` that `buffer` ends with, or `None` if there is no such
/// prefix. A full occurrence of `token` at the end of `buffer` is not a
/// partial match and returns `None`.
pub fn ends_with_partial_token(buffer: &str, token: &str) -> Option<usize> {
    if buffer.is_empty() || token.is_empty() {
        return None;
    }
    // Only consider prefix lengths that fall on char boundaries: slicing
    // `token` at an arbitrary byte index (as `1..token.len()` did) panics
    // for multi-byte UTF-8 tokens. For ASCII tokens the candidate set is
    // identical to the original `1..token.len()` range.
    token
        .char_indices()
        .map(|(i, _)| i)
        .skip(1) // index 0 is the empty prefix; never a match
        .find(|&i| buffer.ends_with(&token[..i]))
}
/// Reset state for the current tool being parsed (used when skipping invalid tools).
/// Preserves the parser's overall progress (current_tool_id, prev_tool_call_arr)
/// while discarding the state specific to the current, incomplete tool.
pub fn reset_current_tool_state(
    buffer: &mut String,
    current_tool_name_sent: &mut bool,
    streamed_args_for_tool: &mut Vec<String>,
    prev_tool_call_arr: &[Value],
) {
    buffer.clear();
    *current_tool_name_sent = false;
    // An entry was appended for the invalid tool iff streamed_args_for_tool
    // grew past the number of completed tools; drop that trailing entry so
    // the two arrays stay in sync.
    let completed = prev_tool_call_arr.len();
    if streamed_args_for_tool.len() > completed {
        streamed_args_for_tool.pop();
    }
}
/// Reset the entire parser state (used at the start of a new request).
/// All accumulated tool-call data is discarded and every counter/flag
/// returns to its initial value.
pub fn reset_parser_state(
    buffer: &mut String,
    prev_tool_call_arr: &mut Vec<Value>,
    current_tool_id: &mut i32,
    current_tool_name_sent: &mut bool,
    streamed_args_for_tool: &mut Vec<String>,
) {
    *current_tool_id = 0;
    *current_tool_name_sent = false;
    buffer.clear();
    prev_tool_call_arr.clear();
    streamed_args_for_tool.clear();
}
/// Grow both tracking arrays so they can be indexed by `current_tool_id`.
/// A negative id means no active tool, so nothing is resized.
pub fn ensure_capacity(
    current_tool_id: i32,
    prev_tool_call_arr: &mut Vec<Value>,
    streamed_args_for_tool: &mut Vec<String>,
) {
    if current_tool_id < 0 {
        return;
    }
    let needed = current_tool_id as usize + 1;
    // New slots are padded with neutral values (Null / empty string),
    // matching what resize_with produced before.
    while prev_tool_call_arr.len() < needed {
        prev_tool_call_arr.push(Value::Null);
    }
    while streamed_args_for_tool.len() < needed {
        streamed_args_for_tool.push(String::new());
    }
}
/// Check whether `input` is a single, complete, valid JSON value.
pub fn is_complete_json(input: &str) -> bool {
    matches!(serde_json::from_str::<Value>(input), Ok(_))
}
/// Normalize the arguments/parameters field in a tool call object.
/// If the object has "parameters" but not "arguments", copy parameters to arguments.
///
/// # Background
/// Different LLM formats use different field names:
/// - Llama and JSON parsers use "parameters" (correct per JSON Schema spec)
/// - Mistral and Qwen use "arguments"
///
/// This function normalizes to "arguments" for consistent downstream processing.
/// Non-object values pass through untouched.
pub fn normalize_arguments_field(mut obj: Value) -> Value {
    let missing_arguments = obj.get("arguments").is_none();
    if missing_arguments {
        let params = obj.get("parameters").cloned();
        if let (Some(params), Value::Object(map)) = (params, &mut obj) {
            map.insert("arguments".to_string(), params);
        }
    }
    obj
}
/// Handle the entire JSON tool call streaming process for JSON-based parsers.
///
/// This unified function handles all aspects of streaming tool calls:
/// - Parsing partial JSON from the buffer
/// - Validating tool names against available tools
/// - Streaming tool names (Case 1)
/// - Streaming tool arguments (Case 2)
/// - Managing parser state and buffer updates
///
/// Used by JSON, Llama, Mistral, and Qwen parsers.
///
/// # Parameters
/// - `current_text`: The current buffered text being parsed
/// - `start_idx`: Start index of JSON content in current_text
/// - `partial_json`: Mutable reference to partial JSON parser
/// - `tool_indices`: Map of valid tool names to their indices
/// - `buffer`: Mutable parser buffer
/// - `current_tool_id`: Mutable current tool index (-1 means no active tool)
/// - `current_tool_name_sent`: Mutable flag for whether current tool's name was sent
/// - `streamed_args_for_tool`: Mutable accumulator of streamed arguments per tool
/// - `prev_tool_call_arr`: Mutable array of previous tool call states
///
/// # Returns
/// - `Ok(StreamingParseResult)` with any tool call items to stream
/// - `Err(ToolParserError)` if JSON parsing or serialization fails
#[allow(clippy::too_many_arguments)]
pub fn handle_json_tool_streaming(
    current_text: &str,
    start_idx: usize,
    partial_json: &mut crate::tool_parser::partial_json::PartialJson,
    tool_indices: &HashMap<String, usize>,
    buffer: &mut String,
    current_tool_id: &mut i32,
    current_tool_name_sent: &mut bool,
    streamed_args_for_tool: &mut Vec<String>,
    prev_tool_call_arr: &mut Vec<Value>,
) -> ToolParserResult<StreamingParseResult> {
    // Check if we have content to parse
    if start_idx >= current_text.len() {
        return Ok(StreamingParseResult::default());
    }

    // Extract JSON string from current position
    let json_str = &current_text[start_idx..];

    // Parse partial JSON; on failure we simply wait for more chunks rather
    // than surfacing an error (the buffer may hold a truncated value).
    let (obj, end_idx) = match partial_json.parse_value(json_str) {
        Ok(result) => result,
        Err(_) => {
            return Ok(StreamingParseResult::default());
        }
    };

    // Check if JSON is complete: the partial parser must have consumed the
    // whole string AND a strict parse must succeed (guards against trailing
    // garbage after a syntactically closed value).
    let is_complete = end_idx == json_str.len() && serde_json::from_str::<Value>(json_str).is_ok();

    // Validate tool name if present
    if let Some(name) = obj.get("name").and_then(|v| v.as_str()) {
        if !tool_indices.contains_key(name) {
            // Invalid tool name - skip this tool, preserve indexing for next tool
            tracing::warn!("Invalid tool name '{}' - skipping", name);
            reset_current_tool_state(
                buffer,
                current_tool_name_sent,
                streamed_args_for_tool,
                prev_tool_call_arr,
            );
            return Ok(StreamingParseResult::default());
        }
    }

    // Normalize parameters/arguments field
    let current_tool_call = normalize_arguments_field(obj);

    let mut result = StreamingParseResult::default();

    // Case 1: Handle tool name streaming
    if !*current_tool_name_sent {
        if let Some(function_name) = current_tool_call.get("name").and_then(|v| v.as_str()) {
            if tool_indices.contains_key(function_name) {
                // Initialize if first tool
                if *current_tool_id == -1 {
                    *current_tool_id = 0;
                    streamed_args_for_tool.push(String::new());
                } else if *current_tool_id as usize >= streamed_args_for_tool.len() {
                    // Ensure capacity for subsequent tools
                    ensure_capacity(*current_tool_id, prev_tool_call_arr, streamed_args_for_tool);
                }

                // Send tool name with empty parameters
                *current_tool_name_sent = true;
                result.calls.push(ToolCallItem {
                    tool_index: *current_tool_id as usize,
                    name: Some(function_name.to_string()),
                    parameters: String::new(),
                });
            }
        }
    }
    // Case 2: Handle streaming arguments
    else if let Some(cur_arguments) = current_tool_call.get("arguments") {
        let tool_id = *current_tool_id as usize;
        // Byte count of the serialized arguments already streamed to the client.
        let sent = streamed_args_for_tool
            .get(tool_id)
            .map(|s| s.len())
            .unwrap_or(0);
        let cur_args_json = serde_json::to_string(cur_arguments)
            .map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;

        // Compute diff: everything after what we've already sent
        // NOTE(review): this byte-slices at `sent`, which assumes the
        // serialization of the arguments only ever grows monotonically as
        // more chunks arrive; a shorter or reordered serialization would
        // panic or emit a garbled diff. TODO confirm upstream guarantees.
        let diff = cur_args_json[sent..].to_string();

        // Send diff if there's new content
        if !diff.is_empty() {
            // Only accumulate if not complete (on completion the entry is
            // cleared below, so accumulating would be wasted work)
            if !is_complete && tool_id < streamed_args_for_tool.len() {
                streamed_args_for_tool[tool_id].push_str(&diff);
            }
            result.calls.push(ToolCallItem {
                tool_index: tool_id,
                name: None,
                parameters: diff,
            });
        }

        // If JSON is complete, advance to next tool
        if is_complete {
            // Remove processed portion, keep unprocessed content
            *buffer = current_text[start_idx + end_idx..].to_string();
            // Clear completed tool data
            if tool_id < prev_tool_call_arr.len() {
                prev_tool_call_arr[tool_id] = Value::Null;
            }
            *current_tool_name_sent = false;
            if tool_id < streamed_args_for_tool.len() {
                streamed_args_for_tool[tool_id].clear();
            }
            *current_tool_id += 1;
        }
    }

    // Update prev_tool_call_arr with current state (after a completion above,
    // this writes into the slot of the NEXT tool id — mirrors prior behavior)
    if *current_tool_id >= 0 {
        ensure_capacity(*current_tool_id, prev_tool_call_arr, streamed_args_for_tool);
        let tool_id = *current_tool_id as usize;
        if tool_id < prev_tool_call_arr.len() {
            prev_tool_call_arr[tool_id] = current_tool_call;
        }
    }

    Ok(result)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ends_with_partial_token() {
        // Strict prefixes of the tag at the end of the buffer are detected.
        assert!(ends_with_partial_token("hello <|py", "<|python_tag|>").is_some());
        assert!(ends_with_partial_token("hello <|python_tag", "<|python_tag|>").is_some());
        // A full occurrence, an empty buffer, or no overlap at all: no match.
        assert!(ends_with_partial_token("hello <|python_tag|>", "<|python_tag|>").is_none());
        assert!(ends_with_partial_token("", "<|python_tag|>").is_none());
        assert!(ends_with_partial_token("hello world", "<|python_tag|>").is_none());
    }

    #[test]
    fn test_reset_current_tool_state() {
        let mut buf = String::from("partial json");
        let mut name_sent = true;
        let mut streamed = vec!["tool0_args".to_string(), "tool1_partial".to_string()];
        let completed = vec![serde_json::json!({"name": "tool0"})];

        reset_current_tool_state(&mut buf, &mut name_sent, &mut streamed, &completed);

        assert_eq!(buf, "");
        assert!(!name_sent);
        // The in-flight tool1 entry was popped; the completed tool0 entry stays.
        assert_eq!(streamed.len(), 1);
        assert_eq!(streamed[0], "tool0_args");
    }

    #[test]
    fn test_reset_current_tool_state_no_pop_when_synced() {
        let mut buf = String::from("partial json");
        let mut name_sent = true;
        let mut streamed = vec!["tool0_args".to_string()];
        let completed = vec![serde_json::json!({"name": "tool0"})];

        reset_current_tool_state(&mut buf, &mut name_sent, &mut streamed, &completed);

        assert_eq!(buf, "");
        assert!(!name_sent);
        // Lengths already matched, so nothing is popped.
        assert_eq!(streamed.len(), 1);
    }

    #[test]
    fn test_reset_parser_state() {
        let mut buf = String::from("some buffer");
        let mut prev = vec![serde_json::json!({"name": "tool0"})];
        let mut tool_id = 5;
        let mut name_sent = true;
        let mut streamed = vec!["args".to_string()];

        reset_parser_state(
            &mut buf,
            &mut prev,
            &mut tool_id,
            &mut name_sent,
            &mut streamed,
        );

        // Everything returns to its initial value.
        assert_eq!(buf, "");
        assert_eq!(prev.len(), 0);
        assert_eq!(tool_id, 0);
        assert!(!name_sent);
        assert_eq!(streamed.len(), 0);
    }

    #[test]
    fn test_ensure_capacity() {
        let mut prev = vec![];
        let mut streamed = vec![];

        ensure_capacity(2, &mut prev, &mut streamed);

        // Both arrays grow to id + 1 slots, padded with neutral values.
        assert_eq!(prev.len(), 3);
        assert_eq!(streamed.len(), 3);
        assert_eq!(prev[0], Value::Null);
        assert_eq!(streamed[0], "");
    }

    #[test]
    fn test_ensure_capacity_negative_id() {
        let mut prev = vec![];
        let mut streamed = vec![];

        ensure_capacity(-1, &mut prev, &mut streamed);

        // A negative id means no active tool; nothing is resized.
        assert_eq!(prev.len(), 0);
        assert_eq!(streamed.len(), 0);
    }

    #[test]
    fn test_is_complete_json() {
        // Any complete JSON value counts, not just objects.
        assert!(is_complete_json(r#"{"name": "test"}"#));
        assert!(is_complete_json("[1, 2, 3]"));
        assert!(is_complete_json("42"));
        assert!(is_complete_json("true"));
        // Truncated values do not.
        assert!(!is_complete_json(r#"{"name": "#));
        assert!(!is_complete_json("[1, 2,"));
    }

    #[test]
    fn test_normalize_arguments_field() {
        // Case 1: Has parameters, no arguments — parameters are copied over.
        let obj = serde_json::json!({
            "name": "test",
            "parameters": {"key": "value"}
        });
        let normalized = normalize_arguments_field(obj);
        assert_eq!(
            normalized.get("arguments").unwrap(),
            &serde_json::json!({"key": "value"})
        );

        // Case 2: Already has arguments — object is unchanged.
        let obj = serde_json::json!({
            "name": "test",
            "arguments": {"key": "value"}
        });
        let normalized = normalize_arguments_field(obj.clone());
        assert_eq!(normalized, obj);

        // Case 3: Neither field present — object is unchanged.
        let obj = serde_json::json!({"name": "test"});
        let normalized = normalize_arguments_field(obj.clone());
        assert_eq!(normalized, obj);
    }
}
use async_trait::async_trait;
use serde_json::Value;
use crate::protocols::spec::Tool;
use crate::tool_parser::{
errors::{ToolParserError, ToolParserResult},
parsers::helpers,
partial_json::PartialJson,
state::ParseState,
traits::ToolParser,
types::{FunctionCall, StreamResult, ToolCall},
types::{FunctionCall, StreamingParseResult, ToolCall},
};
/// JSON format parser for tool calls
......@@ -18,6 +20,24 @@ use crate::tool_parser::{
pub struct JsonParser {
/// Parser for handling incomplete JSON during streaming
partial_json: PartialJson,
/// Buffer for accumulating incomplete patterns across chunks
buffer: String,
/// Stores complete tool call info (name and arguments) for each tool being parsed
prev_tool_call_arr: Vec<Value>,
/// Index of currently streaming tool call (-1 means no active tool)
current_tool_id: i32,
/// Flag for whether current tool's name has been sent to client
current_tool_name_sent: bool,
/// Tracks raw JSON string content streamed to client for each tool's arguments
streamed_args_for_tool: Vec<String>,
/// Separator between multiple tool calls
tool_call_separator: &'static str,
}
impl JsonParser {
......@@ -25,6 +45,12 @@ impl JsonParser {
pub fn new() -> Self {
Self {
partial_json: PartialJson::default(),
buffer: String::new(),
prev_tool_call_arr: Vec::new(),
current_tool_id: -1,
current_tool_name_sent: false,
streamed_args_for_tool: Vec::new(),
tool_call_separator: ",",
}
}
......@@ -158,25 +184,9 @@ impl JsonParser {
Ok(tools)
}
/// Check if text contains JSON tool call markers (complete markers)
fn has_tool_markers(&self, text: &str) -> bool {
(text.contains('{') || text.contains('[')) && text.contains("name")
}
/// Check if buffer could be building toward a tool call pattern
fn has_partial_start_token(&self, buffer: &str) -> bool {
// Check if buffer ends with a partial match of tool call patterns
let patterns = [r#"{"name""#, r#"[{"name""#];
for pattern in &patterns {
// Check if buffer ends with any partial of this pattern
for i in 1..=buffer.len().min(pattern.len()) {
if pattern.starts_with(&buffer[buffer.len() - i..]) {
return true;
}
}
}
false
/// Check if text contains tool calls
fn has_tool_call(&self, text: &str) -> bool {
text.contains('[') || text.contains('{')
}
}
......@@ -206,79 +216,62 @@ impl ToolParser for JsonParser {
}
async fn parse_incremental(
&self,
&mut self,
chunk: &str,
state: &mut ParseState,
) -> ToolParserResult<StreamResult> {
state.buffer.push_str(chunk);
let trimmed = state.buffer.trim();
// If no tool markers and not a partial token, return as normal text │ │
if !self.has_tool_markers(trimmed) && !self.has_partial_start_token(trimmed) {
let normal_text = std::mem::take(&mut state.buffer);
return Ok(StreamResult::NormalText(normal_text));
tools: &[Tool],
) -> ToolParserResult<StreamingParseResult> {
// Append new text to buffer
self.buffer.push_str(chunk);
let current_text = &self.buffer.clone();
// Check if current_text has tool_call
let has_tool_start = self.has_tool_call(current_text)
|| (self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator));
if !has_tool_start {
let normal_text = self.buffer.clone();
self.buffer.clear();
return Ok(StreamingParseResult {
normal_text,
calls: vec![],
});
}
// Try to parse with partial JSON parser
match self.partial_json.parse_value(trimmed) {
Ok((value, consumed)) => {
// Check if we have a complete JSON structure
if consumed == trimmed.len() {
// Check if this is truly complete
let looks_complete = trimmed.ends_with('}') || trimmed.ends_with(']');
if looks_complete {
// Complete JSON, parse tool calls
let tools = self.parse_json_value(&value)?;
if !tools.is_empty() {
// Clear buffer since we consumed everything
state.buffer.clear();
// Return the first tool as complete
// TODO simplified version, address more complex version
if let Some(tool) = tools.into_iter().next() {
return Ok(StreamResult::ToolComplete(tool));
}
}
}
} else {
// Partial JSON, try to extract tool name
if let Some(name) = value.get("name").and_then(|v| v.as_str()) {
// TODO simplified version, address more complex version
// Just return the tool name once we see it
if !state.in_string {
state.in_string = true; // Use as a flag for "name sent"
return Ok(StreamResult::ToolName {
index: 0,
name: name.to_string(),
});
}
// Build tool indices
let tool_indices = helpers::get_tool_indices(tools);
// Check for complete arguments
if let Some(args) =
value.get("arguments").or_else(|| value.get("parameters"))
{
if let Ok(args_str) = serde_json::to_string(args) {
// Return arguments as a single update
return Ok(StreamResult::ToolArguments {
index: 0,
arguments: args_str,
});
}
}
}
}
}
Err(_) => {
// Failed to parse even as partial JSON
// Continue waiting for more data
// Determine start index for JSON parsing
// JSON can start with [ (array) or { (single object)
let start_idx = if let Some(bracket_pos) = current_text.find('[') {
let brace_pos = current_text.find('{');
match brace_pos {
Some(bp) if bp < bracket_pos => bp,
_ => bracket_pos,
}
}
Ok(StreamResult::Incomplete)
} else if let Some(brace_pos) = current_text.find('{') {
brace_pos
} else if self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator) {
self.tool_call_separator.len()
} else {
0
};
helpers::handle_json_tool_streaming(
current_text,
start_idx,
&mut self.partial_json,
&tool_indices,
&mut self.buffer,
&mut self.current_tool_id,
&mut self.current_tool_name_sent,
&mut self.streamed_args_for_tool,
&mut self.prev_tool_call_arr,
)
}
fn detect_format(&self, text: &str) -> bool {
self.has_tool_markers(text)
let trimmed = text.trim();
(trimmed.starts_with('[') || trimmed.starts_with('{')) && trimmed.contains(r#""name""#)
}
}
use async_trait::async_trait;
use regex::Regex;
use serde_json::Value;
use crate::protocols::spec::Tool;
use crate::tool_parser::{
errors::ToolParserResult,
partial_json::PartialJson,
state::ParseState,
parsers::helpers,
traits::ToolParser,
types::{FunctionCall, StreamResult, ToolCall},
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
};
/// Kimi K2 format parser for tool calls
......@@ -19,12 +21,32 @@ use crate::tool_parser::{
/// - Function calls with explicit indexing
/// - JSON arguments
pub struct KimiK2Parser {
    /// Parser for handling incomplete JSON during streaming
    partial_json: PartialJson,
    /// Regex for extracting complete tool calls
    tool_call_extractor: Regex,
    /// Regex for extracting partial tool calls (streaming)
    stream_tool_call_extractor: Regex,
    /// Regex pattern for removing completed tool calls from buffer
    tool_call_end_pattern: Regex,
    /// Robust parser for ids like "functions.search:0" or fallback "search:0"
    tool_call_id_regex: Regex,
    /// Buffer for accumulating incomplete patterns across chunks
    buffer: String,
    /// Stores complete tool call info (name and arguments) for each tool being parsed
    prev_tool_call_arr: Vec<Value>,
    /// Index of currently streaming tool call (-1 means no active tool);
    /// incremented when a tool call completes
    current_tool_id: i32,
    /// Flag for whether current tool's name has been sent to client;
    /// reset to false when advancing to the next tool call
    current_tool_name_sent: bool,
    /// Tracks raw JSON string content streamed to client for each tool's arguments
    streamed_args_for_tool: Vec<String>,
    /// Tracks the last arguments sent for incremental diffing;
    /// cleared when the current tool call completes
    last_arguments: String,
}
impl KimiK2Parser {
......@@ -38,10 +60,25 @@ impl KimiK2Parser {
let stream_pattern = r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>\{.*)";
let stream_tool_call_extractor = Regex::new(stream_pattern).expect("Valid regex pattern");
// Pattern for removing completed tool calls
let end_pattern = r"<\|tool_call_begin\|>.*?<\|tool_call_end\|>";
let tool_call_end_pattern = Regex::new(end_pattern).expect("Valid regex pattern");
// Robust parser for ids like "functions.search:0" or fallback "search:0"
let id_pattern = r"^(?:functions\.)?(?P<name>[\w\.]+):(?P<index>\d+)$";
let tool_call_id_regex = Regex::new(id_pattern).expect("Valid regex pattern");
Self {
partial_json: PartialJson::default(),
tool_call_extractor,
stream_tool_call_extractor,
tool_call_end_pattern,
tool_call_id_regex,
buffer: String::new(),
prev_tool_call_arr: Vec::new(),
current_tool_id: -1,
current_tool_name_sent: false,
streamed_args_for_tool: Vec::new(),
last_arguments: String::new(),
}
}
......@@ -52,22 +89,13 @@ impl KimiK2Parser {
/// Parse function ID to extract name and index
fn parse_function_id(&self, id: &str) -> Option<(String, usize)> {
// Format: functions.{name}:{index} or namespace.functions.{name}:{index}
// Extract everything after the last dot before the colon as the function name
if let Some(colon_pos) = id.rfind(':') {
let before_colon = &id[..colon_pos];
let index_str = &id[colon_pos + 1..];
// Find the last dot to extract the function name
if let Some(dot_pos) = before_colon.rfind('.') {
let func_name = &before_colon[dot_pos + 1..];
if let Ok(index) = index_str.parse::<usize>() {
return Some((func_name.to_string(), index));
}
}
if let Some(captures) = self.tool_call_id_regex.captures(id) {
let name = captures.name("name")?.as_str().to_string();
let index = captures.name("index")?.as_str().parse::<usize>().ok()?;
Some((name, index))
} else {
None
}
None
}
}
......@@ -140,107 +168,172 @@ impl ToolParser for KimiK2Parser {
}
async fn parse_incremental(
&self,
&mut self,
chunk: &str,
state: &mut ParseState,
) -> ToolParserResult<StreamResult> {
state.buffer.push_str(chunk);
tools: &[Tool],
) -> ToolParserResult<StreamingParseResult> {
self.buffer.push_str(chunk);
let current_text = &self.buffer.clone();
// Check for tool markers
// Check if we have a tool call (either the start token or individual tool call)
let has_tool_call =
self.has_tool_markers(&state.buffer) || state.buffer.contains("<|tool_call_begin|>");
self.has_tool_markers(current_text) || current_text.contains("<|tool_call_begin|>");
if !has_tool_call {
// No tool markers detected - return all buffered content as normal text
let normal_text = std::mem::take(&mut state.buffer);
return Ok(StreamResult::NormalText(normal_text));
let mut normal_text = std::mem::take(&mut self.buffer);
// Remove end tokens if present
for e_token in ["<|tool_calls_section_end|>", "<|tool_call_end|>"] {
normal_text = normal_text.replace(e_token, "");
}
return Ok(StreamingParseResult {
normal_text,
calls: vec![],
});
}
// Check for text before tool markers and extract it as normal text
let marker1_pos = state.buffer.find("<|tool_calls_section_begin|>");
let marker2_pos = state.buffer.find("<|tool_call_begin|>");
let marker_pos = marker1_pos.iter().chain(marker2_pos.iter()).min().copied();
// Build tool indices for validation
let tool_indices = helpers::get_tool_indices(tools);
if let Some(pos) = marker_pos {
if pos > 0 {
// We have text before the tool marker - extract it as normal text
let normal_text: String = state.buffer.drain(..pos).collect();
return Ok(StreamResult::NormalText(normal_text));
}
}
let mut calls: Vec<ToolCallItem> = Vec::new();
// Try to match streaming pattern
if let Some(captures) = self.stream_tool_call_extractor.captures(&state.buffer) {
if let Some(captures) = self.stream_tool_call_extractor.captures(current_text) {
if let (Some(id_match), Some(args_match)) = (
captures.name("tool_call_id"),
captures.name("function_arguments"),
) {
let function_id = id_match.as_str();
let partial_args = args_match.as_str();
let function_args = args_match.as_str();
// Parse function ID
if let Some((func_name, _index)) = self.parse_function_id(function_id) {
// Send function name if not sent yet
if !state.in_string {
state.in_string = true; // Mark name as sent
return Ok(StreamResult::ToolName {
index: 0,
name: func_name.clone(),
});
// Validate tool name
if !tool_indices.contains_key(&func_name) {
// Invalid tool name - skip this tool, preserve indexing for next tool
tracing::warn!("Invalid tool name '{}' - skipping", func_name);
helpers::reset_current_tool_state(
&mut self.buffer,
&mut self.current_tool_name_sent,
&mut self.streamed_args_for_tool,
&self.prev_tool_call_arr,
);
return Ok(StreamingParseResult::default());
}
// Check if we have a complete tool call
if let Some(end_pos) = partial_args.find("<|tool_call_end|>") {
// Extract just the JSON part
let json_args = &partial_args[..end_pos];
// Initialize state if this is the first tool call
if self.current_tool_id == -1 {
self.current_tool_id = 0;
self.prev_tool_call_arr = Vec::new();
self.streamed_args_for_tool = vec![String::new()];
}
// Validate and parse JSON
if serde_json::from_str::<serde_json::Value>(json_args).is_ok() {
// Generate unique ID
let id = format!("kimi_call_{}", uuid::Uuid::new_v4());
// Ensure we have enough entries in our tracking arrays
helpers::ensure_capacity(
self.current_tool_id,
&mut self.prev_tool_call_arr,
&mut self.streamed_args_for_tool,
);
// Send tool name if not sent yet
if !self.current_tool_name_sent {
calls.push(ToolCallItem {
tool_index: self.current_tool_id as usize,
name: Some(func_name.clone()),
parameters: String::new(),
});
self.current_tool_name_sent = true;
let tool = ToolCall {
id,
r#type: "function".to_string(),
function: FunctionCall {
name: func_name,
arguments: json_args.to_string(),
},
// Store the tool call info for serving layer completions endpoint
let tool_id = self.current_tool_id as usize;
if self.prev_tool_call_arr.len() <= tool_id {
self.prev_tool_call_arr
.resize_with(tool_id + 1, || Value::Null);
}
self.prev_tool_call_arr[tool_id] = serde_json::json!({
"name": func_name,
"arguments": {},
});
} else {
// Compute incremental diff
let argument_diff = if function_args.starts_with(&self.last_arguments) {
&function_args[self.last_arguments.len()..]
} else {
function_args
};
// Split by end token before sending (like Python does)
let parsed_args_diff =
if let Some(pos) = argument_diff.find("<|tool_call_end|>") {
&argument_diff[..pos]
} else {
argument_diff
};
// Find where this tool call ends in the buffer
if let Some(tool_end) = state.buffer.find("<|tool_call_end|>") {
let end_pos = tool_end + "<|tool_call_end|>".len();
state.buffer.drain(..end_pos);
if !parsed_args_diff.is_empty() {
calls.push(ToolCallItem {
tool_index: self.current_tool_id as usize,
name: None,
parameters: parsed_args_diff.to_string(),
});
// Note: Python adds full diff to _last_arguments, not just parsed part
self.last_arguments.push_str(argument_diff);
let tool_id = self.current_tool_id as usize;
if tool_id < self.streamed_args_for_tool.len() {
self.streamed_args_for_tool[tool_id].push_str(parsed_args_diff);
}
// Reset state for next tool
state.in_string = false;
return Ok(StreamResult::ToolComplete(tool));
}
} else {
// Try to parse partial JSON for streaming arguments
match self.partial_json.parse_value(partial_args) {
Ok((value, _consumed)) => {
let args_str = serde_json::to_string(&value)
.unwrap_or_else(|_| "{}".to_string());
return Ok(StreamResult::ToolArguments {
index: 0,
arguments: args_str,
});
// Check completeness - split by end token first
let parsed_args = if let Some(pos) = function_args.find("<|tool_call_end|>")
{
&function_args[..pos]
} else {
function_args
};
if helpers::is_complete_json(parsed_args) {
// Update the stored arguments
if let Ok(parsed_args_value) =
serde_json::from_str::<Value>(parsed_args)
{
let tool_id = self.current_tool_id as usize;
if tool_id < self.prev_tool_call_arr.len() {
if let Some(obj) =
self.prev_tool_call_arr[tool_id].as_object_mut()
{
obj.insert("arguments".to_string(), parsed_args_value);
}
}
}
Err(_) => {
// Can't parse yet, keep buffering
// Find the end of the current tool call and remove only that part from buffer
if let Some(mat) = self.tool_call_end_pattern.find(current_text) {
// Remove the completed tool call from buffer, keep any remaining content
self.buffer = current_text[mat.end()..].to_string();
} else {
self.buffer.clear();
}
let result = StreamingParseResult {
normal_text: String::new(),
calls,
};
self.current_tool_id += 1;
self.last_arguments.clear();
self.current_tool_name_sent = false;
return Ok(result);
}
}
}
}
}
Ok(StreamResult::Incomplete)
Ok(StreamingParseResult {
normal_text: String::new(),
calls,
})
}
fn detect_format(&self, text: &str) -> bool {
......
......@@ -2,23 +2,44 @@ use async_trait::async_trait;
use serde_json::Value;
use uuid;
use crate::protocols::spec::Tool;
use crate::tool_parser::{
errors::{ToolParserError, ToolParserResult},
parsers::helpers,
partial_json::PartialJson,
state::ParseState,
traits::ToolParser,
types::{FunctionCall, StreamResult, ToolCall},
types::{FunctionCall, StreamingParseResult, ToolCall},
};
/// Llama 3.2 format parser for tool calls
///
/// Handles the Llama 3.2 specific format:
/// `<|python_tag|>{"name": "func", "arguments": {...}}`
/// `<|python_tag|>{"name": "func", "parameters": {...}}`
///
/// Also supports plain JSON without the python_tag prefix
pub struct LlamaParser {
    /// Parser for handling incomplete JSON during streaming
    partial_json: PartialJson,
    /// Buffer for accumulating incomplete patterns across chunks
    buffer: String,
    /// Stores complete tool call info (name and arguments) for each tool being parsed
    prev_tool_call_arr: Vec<Value>,
    /// Index of currently streaming tool call (-1 means no active tool)
    current_tool_id: i32,
    /// Flag for whether current tool's name has been sent to client
    current_tool_name_sent: bool,
    /// Tracks raw JSON string content streamed to client for each tool's arguments
    streamed_args_for_tool: Vec<String>,
    /// Token configuration: calls open with "<|python_tag|>" and multiple
    /// calls are separated by ";" (values assigned in `new`)
    bot_token: &'static str,
    tool_call_separator: &'static str,
}
impl LlamaParser {
......@@ -26,6 +47,13 @@ impl LlamaParser {
pub fn new() -> Self {
Self {
partial_json: PartialJson::default(),
buffer: String::new(),
prev_tool_call_arr: Vec::new(),
current_tool_id: -1,
current_tool_name_sent: false,
streamed_args_for_tool: Vec::new(),
bot_token: "<|python_tag|>",
tool_call_separator: ";",
}
}
......@@ -76,39 +104,6 @@ impl LlamaParser {
}
}
/// Parse JSON value(s) into tool calls
fn parse_json_value(&self, value: &Value) -> ToolParserResult<Vec<ToolCall>> {
let mut tools = Vec::new();
match value {
Value::Array(arr) => {
// Parse each element in the array
for item in arr {
if let Some(tool) = self.parse_single_object(item)? {
tools.push(tool);
}
}
}
Value::Object(_) => {
// Single tool call
if let Some(tool) = self.parse_single_object(value)? {
tools.push(tool);
}
}
_ => {
// Not a valid tool call format
return Ok(vec![]);
}
}
Ok(tools)
}
/// Check if text contains potential tool call markers
fn has_python_tag(&self, text: &str) -> bool {
text.contains("<|python_tag|>")
}
/// Parse semicolon-separated JSON objects
fn parse_semicolon_separated(&self, content: &str) -> ToolParserResult<Vec<ToolCall>> {
let mut all_tools = Vec::new();
......@@ -136,6 +131,11 @@ impl LlamaParser {
Ok(all_tools)
}
/// Check if text has tool call
fn has_tool_call(&self, text: &str) -> bool {
text.contains("<|python_tag|>") || text.contains('{')
}
}
impl Default for LlamaParser {
......@@ -185,137 +185,57 @@ impl ToolParser for LlamaParser {
}
async fn parse_incremental(
&self,
&mut self,
chunk: &str,
state: &mut ParseState,
) -> ToolParserResult<StreamResult> {
state.buffer.push_str(chunk);
// In streaming mode, be more lenient - check for potential JSON start
let has_potential_json = state.buffer.contains('{');
let has_tag = self.has_python_tag(&state.buffer);
// If we have neither python_tag nor potential JSON structure, return as normal text
if !has_tag && !has_potential_json {
// No relevant markers detected - return all buffered content as normal text
let normal_text = std::mem::take(&mut state.buffer);
return Ok(StreamResult::NormalText(normal_text));
}
// If we only have '{' without more content, wait for more data
let trimmed = state.buffer.trim();
if (trimmed == "{") && !has_tag {
return Ok(StreamResult::Incomplete);
}
// Check for text before python_tag and extract it as normal text
if let Some(tag_pos) = state.buffer.find("<|python_tag|>") {
if tag_pos > 0 {
// We have text before the python_tag - extract it as normal text
let normal_text: String = state.buffer.drain(..tag_pos).collect();
return Ok(StreamResult::NormalText(normal_text));
}
} else {
// For JSON without python_tag, look for the start of JSON structure
let brace_pos = state.buffer.find('{');
let bracket_pos = state.buffer.find('[');
let json_pos = brace_pos.iter().chain(bracket_pos.iter()).min().copied();
if let Some(pos) = json_pos {
if pos > 0 {
// We have text before JSON structure - extract it as normal text
let normal_text: String = state.buffer.drain(..pos).collect();
return Ok(StreamResult::NormalText(normal_text));
}
tools: &[Tool],
) -> ToolParserResult<StreamingParseResult> {
// Append new text to buffer
self.buffer.push_str(chunk);
let current_text = &self.buffer.clone();
// Check if current_text has tool_call
let has_tool_start = self.has_tool_call(current_text)
|| (self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator));
if !has_tool_start {
// Only clear buffer if we're sure no tool call is starting
if helpers::ends_with_partial_token(&self.buffer, self.bot_token).is_none() {
let normal_text = self.buffer.clone();
self.buffer.clear();
return Ok(StreamingParseResult {
normal_text,
calls: vec![],
});
} else {
// Might be partial bot_token, keep buffering
return Ok(StreamingParseResult::default());
}
}
// Extract JSON content based on whether we have python_tag
let (json_content, content_start_pos) = if self.has_python_tag(&state.buffer) {
// Extract content after python_tag
if let Some(tag_pos) = state.buffer.find("<|python_tag|>") {
let start = tag_pos + "<|python_tag|>".len();
(&state.buffer[start..], start)
} else {
(&state.buffer[..], 0)
}
// Build tool indices
let tool_indices = helpers::get_tool_indices(tools);
// Determine start index for JSON parsing
let start_idx = if let Some(pos) = current_text.find(self.bot_token) {
pos + self.bot_token.len()
} else if self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator) {
self.tool_call_separator.len()
} else {
// Find where the actual content starts after trimming
let trimmed = state.buffer.trim_start();
let trim_offset = state.buffer.len() - trimmed.len();
(trimmed.trim_end(), trim_offset)
0
};
// Check if we have a semicolon separator (multiple tools)
if let Some(semicolon_pos) = json_content.find(';') {
// We have multiple tools - try to parse the first one
let first_json = &json_content[..semicolon_pos];
if let Ok(value) = serde_json::from_str::<Value>(first_json.trim()) {
if let Some(tool) = self.parse_single_object(&value)? {
// Remove the parsed JSON and semicolon from the buffer
let end_pos = content_start_pos + semicolon_pos + 1; // +1 to include the semicolon
state.buffer.drain(content_start_pos..end_pos);
return Ok(StreamResult::ToolComplete(tool));
}
}
}
// Try to parse with partial JSON parser
match self.partial_json.parse_value(json_content) {
Ok((value, consumed)) => {
// Check if we have a complete JSON structure
if consumed == json_content.len() {
// Check if this is truly complete
let looks_complete = json_content.ends_with('}') || json_content.ends_with(']');
if looks_complete {
// Complete JSON, parse tool calls
let tools = self.parse_json_value(&value)?;
if !tools.is_empty() {
// Clear buffer since we consumed everything
state.buffer.clear();
// Return the first tool as complete
if let Some(tool) = tools.into_iter().next() {
return Ok(StreamResult::ToolComplete(tool));
}
}
}
} else {
// Partial JSON, try to extract tool name for streaming
if let Some(name) = value.get("name").and_then(|v| v.as_str()) {
// Return tool name once we see it
if !state.in_string {
state.in_string = true; // Use as a flag for "name sent"
return Ok(StreamResult::ToolName {
index: 0,
name: name.to_string(),
});
}
// Check for complete arguments
if let Some(args) =
value.get("arguments").or_else(|| value.get("parameters"))
{
if let Ok(args_str) = serde_json::to_string(args) {
return Ok(StreamResult::ToolArguments {
index: 0,
arguments: args_str,
});
}
}
}
}
}
Err(_) => {
// Failed to parse even as partial JSON
// Continue waiting for more data
}
}
Ok(StreamResult::Incomplete)
helpers::handle_json_tool_streaming(
current_text,
start_idx,
&mut self.partial_json,
&tool_indices,
&mut self.buffer,
&mut self.current_tool_id,
&mut self.current_tool_name_sent,
&mut self.streamed_args_for_tool,
&mut self.prev_tool_call_arr,
)
}
fn detect_format(&self, text: &str) -> bool {
......
use async_trait::async_trait;
use serde_json::Value;
use crate::protocols::spec::Tool;
use crate::tool_parser::{
errors::{ToolParserError, ToolParserResult},
parsers::helpers,
partial_json::PartialJson,
state::ParseState,
traits::ToolParser,
types::{FunctionCall, StreamResult, ToolCall},
types::{FunctionCall, StreamingParseResult, ToolCall},
};
/// Mistral format parser for tool calls
......@@ -21,6 +23,25 @@ use crate::tool_parser::{
pub struct MistralParser {
    /// Parser for handling incomplete JSON during streaming
    partial_json: PartialJson,
    /// Buffer for accumulating incomplete patterns across chunks
    buffer: String,
    /// Stores complete tool call info (name and arguments) for each tool being parsed
    prev_tool_call_arr: Vec<Value>,
    /// Index of currently streaming tool call (-1 means no active tool)
    current_tool_id: i32,
    /// Flag for whether current tool's name has been sent to client
    current_tool_name_sent: bool,
    /// Tracks raw JSON string content streamed to client for each tool's arguments
    streamed_args_for_tool: Vec<String>,
    /// Token configuration: tool calls open with "[TOOL_CALLS] [" and are
    /// separated by ", " (values assigned in `new`)
    bot_token: &'static str,
    tool_call_separator: &'static str,
}
impl MistralParser {
......@@ -28,19 +49,16 @@ impl MistralParser {
pub fn new() -> Self {
Self {
partial_json: PartialJson::default(),
buffer: String::new(),
prev_tool_call_arr: Vec::new(),
current_tool_id: -1,
current_tool_name_sent: false,
streamed_args_for_tool: Vec::new(),
bot_token: "[TOOL_CALLS] [",
tool_call_separator: ", ",
}
}
/// Extract JSON array using bracket counting
///
/// Handles nested brackets in JSON content by tracking:
/// - String boundaries (quotes)
/// - Escape sequences
/// - Bracket depth
fn extract_json_array<'a>(&self, text: &'a str) -> Option<&'a str> {
self.extract_json_array_with_pos(text).map(|(_, json)| json)
}
fn extract_json_array_with_pos<'a>(&self, text: &'a str) -> Option<(usize, &'a str)> {
const BOT_TOKEN: &str = "[TOOL_CALLS] [";
......@@ -100,14 +118,14 @@ impl MistralParser {
let mut tools = Vec::new();
if let Value::Array(arr) = value {
for (index, item) in arr.iter().enumerate() {
if let Some(tool) = self.parse_single_object(item, index)? {
for item in arr.iter() {
if let Some(tool) = self.parse_single_object(item)? {
tools.push(tool);
}
}
} else {
// Single object case (shouldn't happen with Mistral format, but handle it)
if let Some(tool) = self.parse_single_object(&value, 0)? {
if let Some(tool) = self.parse_single_object(&value)? {
tools.push(tool);
}
}
......@@ -116,7 +134,7 @@ impl MistralParser {
}
/// Parse a single JSON object into a ToolCall
fn parse_single_object(&self, obj: &Value, index: usize) -> ToolParserResult<Option<ToolCall>> {
fn parse_single_object(&self, obj: &Value) -> ToolParserResult<Option<ToolCall>> {
let name = obj.get("name").and_then(|v| v.as_str());
if let Some(name) = name {
......@@ -128,8 +146,12 @@ impl MistralParser {
let arguments = serde_json::to_string(args)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
// Generate ID with index for multiple tools
let id = format!("mistral_call_{}", index);
// Generate unique ID
let id = obj
.get("id")
.and_then(|v| v.as_str())
.map(String::from)
.unwrap_or_else(|| format!("mistral_call_{}", uuid::Uuid::new_v4()));
Ok(Some(ToolCall {
id,
......@@ -188,95 +210,57 @@ impl ToolParser for MistralParser {
}
async fn parse_incremental(
&self,
&mut self,
chunk: &str,
state: &mut ParseState,
) -> ToolParserResult<StreamResult> {
state.buffer.push_str(chunk);
// Check if we have the start marker
if !self.has_tool_markers(&state.buffer) {
// No tool markers detected - return all buffered content as normal text
let normal_text = std::mem::take(&mut state.buffer);
return Ok(StreamResult::NormalText(normal_text));
}
// Check for text before [TOOL_CALLS] and extract it as normal text
if let Some(marker_pos) = state.buffer.find("[TOOL_CALLS]") {
if marker_pos > 0 {
// We have text before the tool marker - extract it as normal text
let normal_text: String = state.buffer.drain(..marker_pos).collect();
return Ok(StreamResult::NormalText(normal_text));
tools: &[Tool],
) -> ToolParserResult<StreamingParseResult> {
// Append new text to buffer
self.buffer.push_str(chunk);
let current_text = &self.buffer.clone();
// Check if current_text has tool_call
let has_tool_start = self.has_tool_markers(current_text)
|| (self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator));
if !has_tool_start {
// Only clear buffer if we're sure no tool call is starting
if helpers::ends_with_partial_token(&self.buffer, self.bot_token).is_none() {
let normal_text = self.buffer.clone();
self.buffer.clear();
return Ok(StreamingParseResult {
normal_text,
calls: vec![],
});
} else {
// Might be partial bot_token, keep buffering
return Ok(StreamingParseResult::default());
}
}
// Try to extract complete JSON array
if let Some(json_array) = self.extract_json_array(&state.buffer) {
// Parse with partial JSON to handle incomplete content
match self.partial_json.parse_value(json_array) {
Ok((value, consumed)) => {
// Check if we have a complete JSON structure
if consumed == json_array.len() {
// Complete JSON, parse tool calls
let tools = if let Value::Array(arr) = value {
let mut result = Vec::new();
for (index, item) in arr.iter().enumerate() {
if let Some(tool) = self.parse_single_object(item, index)? {
result.push(tool);
}
}
result
} else {
vec![]
};
if !tools.is_empty() {
// Clear buffer since we consumed everything
state.buffer.clear();
// Return the first tool (simplified for Phase 3)
// Full multi-tool streaming will be implemented later
if let Some(tool) = tools.into_iter().next() {
return Ok(StreamResult::ToolComplete(tool));
}
}
} else {
// Partial JSON - try to extract tool name for streaming
if let Value::Array(arr) = value {
if let Some(first_tool) = arr.first() {
if let Some(name) = first_tool.get("name").and_then(|v| v.as_str())
{
// Check if we've already sent the name
if !state.in_string {
state.in_string = true; // Use as flag for "name sent"
return Ok(StreamResult::ToolName {
index: 0,
name: name.to_string(),
});
}
// Check for arguments
if let Some(args) = first_tool.get("arguments") {
if let Ok(args_str) = serde_json::to_string(args) {
return Ok(StreamResult::ToolArguments {
index: 0,
arguments: args_str,
});
}
}
}
}
}
}
}
Err(_) => {
// Failed to parse even as partial JSON
// Keep buffering
}
}
}
// Build tool indices
let tool_indices = helpers::get_tool_indices(tools);
Ok(StreamResult::Incomplete)
// Determine start index for JSON parsing
let start_idx = if let Some(pos) = current_text.find(self.bot_token) {
pos + self.bot_token.len()
} else if self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator) {
self.tool_call_separator.len()
} else {
0
};
helpers::handle_json_tool_streaming(
current_text,
start_idx,
&mut self.partial_json,
&tool_indices,
&mut self.buffer,
&mut self.current_tool_id,
&mut self.current_tool_name_sent,
&mut self.streamed_args_for_tool,
&mut self.prev_tool_call_arr,
)
}
fn detect_format(&self, text: &str) -> bool {
......
......@@ -15,6 +15,9 @@ pub mod pythonic_parser;
pub mod qwen_parser;
pub mod step3_parser;
// Shared helpers and utilities
pub mod helpers;
// Re-export parser types for convenience
pub use deepseek_parser::DeepSeekParser;
pub use glm4_moe_parser::Glm4MoeParser;
......
......@@ -15,11 +15,13 @@ use rustpython_parser::{parse, Mode};
use serde_json::{Map, Number, Value};
use std::sync::OnceLock;
use crate::protocols::spec::Tool;
use crate::tool_parser::{
errors::{ToolParserError, ToolParserResult},
state::ParseState,
parsers::helpers,
traits::ToolParser,
types::{FunctionCall, StreamResult, ToolCall},
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
};
static PYTHONIC_BLOCK_REGEX: OnceLock<Regex> = OnceLock::new();
......@@ -37,13 +39,23 @@ fn pythonic_block_regex() -> &'static Regex {
}
/// Parser for Pythonic tool call format
///
/// NOTE: the stale unit-struct declaration and its `#[derive(Default)]`
/// (left over from a merge) are removed: the derive would conflict with the
/// manual `Default` impl that delegates to `new()`, and a type cannot be
/// defined twice.
pub struct PythonicParser {
    /// Buffer for accumulating chunks across incremental parse calls
    buffer: String,
}
impl Default for PythonicParser {
    // Delegate to `new()` so both construction paths share one implementation.
    fn default() -> Self {
        Self::new()
    }
}
impl PythonicParser {
/// Create a new Pythonic parser
pub fn new() -> Self {
Self
Self {
buffer: String::new(),
}
}
/// Extract the first pythonic tool call block and return it along with the
......@@ -105,23 +117,90 @@ impl ToolParser for PythonicParser {
}
async fn parse_incremental(
&self,
&mut self,
chunk: &str,
state: &mut ParseState,
) -> ToolParserResult<StreamResult> {
state.buffer.push_str(chunk);
let cleaned = Self::strip_special_tokens(&state.buffer);
if let Some((tool_calls_text, _)) = self.extract_tool_calls(&cleaned) {
if let Ok(tools) = self.parse_tool_call_block(&tool_calls_text) {
if let Some(tool) = tools.into_iter().next() {
state.buffer.clear();
return Ok(StreamResult::ToolComplete(tool));
tools: &[Tool],
) -> ToolParserResult<StreamingParseResult> {
self.buffer.push_str(chunk);
let cleaned = Self::strip_special_tokens(&self.buffer);
// Look for opening bracket
if let Some(start) = cleaned.find('[') {
let normal_text = if start > 0 {
cleaned[..start].to_string()
} else {
String::new()
};
// Look for matching closing bracket
if let Some(end) = find_matching_bracket(&cleaned, start) {
// Found complete tool call - extract it and parse using parse_complete
let call_text = &cleaned[start..=end];
match self.parse_complete(call_text).await {
Ok((_, calls)) => {
// Update buffer with remaining text after tool call
let remaining_text = &cleaned[end + 1..];
self.buffer = remaining_text.to_string();
// Validate tool names and convert ToolCall to ToolCallItem
let tool_indices = helpers::get_tool_indices(tools);
let items: Vec<ToolCallItem> = calls
.into_iter()
.enumerate()
.filter_map(|(idx, tool)| {
if !tool_indices.contains_key(&tool.function.name) {
tracing::warn!(
"Invalid tool name '{}' - skipping",
tool.function.name
);
return None;
}
Some(ToolCallItem {
tool_index: idx,
name: Some(tool.function.name),
parameters: tool.function.arguments,
})
})
.collect();
return Ok(StreamingParseResult {
normal_text,
calls: items,
});
}
Err(e) => {
tracing::warn!("Failed to parse pythonic tool call: {}", e);
// Clear buffer on error
self.buffer.clear();
return Ok(StreamingParseResult::default());
}
}
} else {
// We have an opening bracket but no closing bracket yet
// Put back everything from the bracket onwards
self.buffer = cleaned[start..].to_string();
if !normal_text.is_empty() {
return Ok(StreamingParseResult {
normal_text,
calls: vec![],
});
}
// Still accumulating a potential tool call
return Ok(StreamingParseResult::default());
}
}
Ok(StreamResult::Incomplete)
// No tool call bracket found
self.buffer.clear();
Ok(StreamingParseResult {
normal_text: cleaned,
calls: vec![],
})
}
fn detect_format(&self, text: &str) -> bool {
......@@ -134,6 +213,25 @@ impl ToolParser for PythonicParser {
}
}
/// Find the matching closing bracket for the opening bracket at byte offset
/// `start`, properly handling nested brackets.
///
/// Returns the byte offset of the matching `]`, or `None` if no matching
/// bracket is found. The scan works on bytes so the result can be used
/// directly for string slicing: callers obtain `start` from `str::find`
/// (a byte offset) and slice with the returned index, but the previous
/// implementation enumerated `chars()`, so both the `skip(start)` and the
/// returned index were *character* positions — wrong whenever multi-byte
/// UTF-8 text precedes the bracket. `[` and `]` are ASCII, so matching on
/// raw bytes is safe.
fn find_matching_bracket(buffer: &str, start: usize) -> Option<usize> {
    let mut depth: i32 = 0;
    for (idx, byte) in buffer.bytes().enumerate().skip(start) {
        match byte {
            b'[' => depth += 1,
            b']' => {
                depth -= 1;
                if depth == 0 {
                    return Some(idx);
                }
            }
            _ => {}
        }
    }
    None // No matching bracket found
}
fn parse_python_expression(source: &str) -> ToolParserResult<Expr> {
let module = parse(source, Mode::Expression, "<pythonic_tool_call>")
.map_err(|err| ToolParserError::ParsingFailed(err.to_string()))?;
......
......@@ -2,12 +2,14 @@ use async_trait::async_trait;
use regex::Regex;
use serde_json::Value;
use crate::protocols::spec::Tool;
use crate::tool_parser::{
errors::{ToolParserError, ToolParserResult},
parsers::helpers,
partial_json::PartialJson,
state::ParseState,
traits::ToolParser,
types::{FunctionCall, StreamResult, ToolCall},
types::{FunctionCall, StreamingParseResult, ToolCall},
};
/// Qwen format parser for tool calls
......@@ -19,11 +21,36 @@ use crate::tool_parser::{
/// - XML-style tags with JSON content
/// - Support for multiple sequential tool calls
/// - Newline-aware parsing
/// - Buffering for partial end tokens
pub struct QwenParser {
    /// Parser for handling incomplete JSON during streaming
    partial_json: PartialJson,
    /// Regex for extracting tool calls in parse_complete
    extractor: Regex,
    /// Buffer for accumulating incomplete patterns across chunks
    buffer: String,
    /// Stores complete tool call info (name and arguments) for each tool being parsed
    prev_tool_call_arr: Vec<Value>,
    /// Index of currently streaming tool call (-1 means no active tool)
    current_tool_id: i32,
    /// Flag for whether current tool's name has been sent to client
    current_tool_name_sent: bool,
    /// Tracks raw JSON string content streamed to client for each tool's arguments
    streamed_args_for_tool: Vec<String>,
    /// Buffer for normal text that might precede partial end tokens
    normal_text_buffer: String,
    /// Token configuration: opening tag "<tool_call>\n", closing tag
    /// "\n</tool_call>", and "\n" as the separator between consecutive
    /// calls (values assigned in `new`)
    bot_token: &'static str,
    eot_token: &'static str,
    tool_call_separator: &'static str,
}
impl QwenParser {
......@@ -36,11 +63,20 @@ impl QwenParser {
Self {
partial_json: PartialJson::default(),
extractor,
buffer: String::new(),
prev_tool_call_arr: Vec::new(),
current_tool_id: -1,
current_tool_name_sent: false,
streamed_args_for_tool: Vec::new(),
normal_text_buffer: String::new(),
bot_token: "<tool_call>\n",
eot_token: "\n</tool_call>",
tool_call_separator: "\n",
}
}
/// Parse a single JSON object into a ToolCall
fn parse_single_object(&self, obj: &Value, index: usize) -> ToolParserResult<Option<ToolCall>> {
fn parse_single_object(&self, obj: &Value) -> ToolParserResult<Option<ToolCall>> {
let name = obj.get("name").and_then(|v| v.as_str());
if let Some(name) = name {
......@@ -52,8 +88,12 @@ impl QwenParser {
let arguments = serde_json::to_string(args)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
// Generate ID with index for multiple tools
let id = format!("qwen_call_{}", index);
// Generate unique ID
let id = obj
.get("id")
.and_then(|v| v.as_str())
.map(String::from)
.unwrap_or_else(|| format!("qwen_call_{}", uuid::Uuid::new_v4()));
Ok(Some(ToolCall {
id,
......@@ -73,42 +113,9 @@ impl QwenParser {
text.contains("<tool_call>")
}
/// Find the start position of a tool call
fn find_tool_start(&self, text: &str) -> Option<usize> {
text.find("<tool_call>\n")
}
/// Find the end position of a tool call
fn find_tool_end(&self, text: &str, start_pos: usize) -> Option<usize> {
let search_from = start_pos + "<tool_call>\n".len();
text[search_from..]
.find("\n</tool_call>")
.map(|pos| search_from + pos + "\n</tool_call>".len())
}
/// Check if buffer ends with a partial token
///
/// Returns `Some(len)` when the last `len` bytes of `buffer` form a prefix
/// (possibly the whole token) of the start token `"<tool_call>\n"` or of the
/// end token `"\n</tool_call>"` — the caller should keep buffering — and
/// `None` when the tail cannot be the beginning of either token.
// NOTE(review): `&buffer[buffer.len() - i..]` slices at byte offsets; a
// multi-byte UTF-8 character at the tail would panic. Confirm inputs end in
// ASCII here, or guard the slice with `is_char_boundary`.
fn ends_with_partial_token(&self, buffer: &str) -> Option<usize> {
// Check for partial start token
let start_token = "<tool_call>\n";
// Use inclusive range to check if entire buffer could be a prefix
for i in 1..=start_token.len().min(buffer.len()) {
if start_token.starts_with(&buffer[buffer.len() - i..]) {
return Some(i);
}
}
// Check for partial end token
let end_token = "\n</tool_call>";
// Only check if buffer ends with a partial match (not the complete token without newline)
// If buffer ends with "</tool_call>", that's not a partial token - it's missing the newline
if buffer.ends_with("</tool_call>") {
// This is a complete end tag, just missing the leading newline
// Not a partial token situation
return None;
}
// Use inclusive range to check if entire buffer could be a prefix
// (tail expression; the function's closing brace is elided from this diff view)
(1..=end_token.len().min(buffer.len()))
.find(|&i| end_token.starts_with(&buffer[buffer.len() - i..]))
/// Report whether the accumulated text contains a tool-call start marker.
///
/// Checks for the bare `<tool_call>` tag — i.e. without the trailing
/// newline that completes the full start token.
fn has_tool_call(&self, text: &str) -> bool {
    const START_MARKER: &str = "<tool_call>";
    text.contains(START_MARKER)
}
}
......@@ -132,17 +139,17 @@ impl ToolParser for QwenParser {
// Extract tool calls
let mut tools = Vec::new();
// NOTE(review): rendered diff — the next two `for` lines are the old
// (enumerated) and new (un-enumerated) loop headers respectively; the
// `index` value disappeared along with index-based call ids.
for (index, captures) in self.extractor.captures_iter(text).enumerate() {
for captures in self.extractor.captures_iter(text) {
// Capture group 1 presumably holds the JSON payload between the
// tool-call tags — regex is defined outside this view.
if let Some(json_str) = captures.get(1) {
let parsed = serde_json::from_str::<Value>(json_str.as_str().trim())
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))
.and_then(|v| self.parse_single_object(&v, index));
.and_then(|v| self.parse_single_object(&v));
match parsed {
Ok(Some(tool)) => tools.push(tool),
// Objects without a usable "name" are skipped without warning.
Ok(None) => continue,
Err(e) => {
// Malformed entries are logged and skipped rather than failing the batch.
tracing::warn!("Failed to parse tool call {}: {:?}", index, e);
tracing::warn!("Failed to parse tool call: {:?}", e);
continue;
}
}
......@@ -158,103 +165,91 @@ impl ToolParser for QwenParser {
}
/// NOTE(review): this span is a commit diff rendered as one stream — lines of
/// the REMOVED implementation (which took `&self` + `state: &mut ParseState`
/// and returned `StreamResult`) are interleaved with lines of the ADDED one
/// (which takes `&mut self` + `tools: &[Tool]`, keeps all state on the parser
/// itself, and returns `StreamingParseResult`). Divider comments below mark
/// which version each run of lines belongs to.
async fn parse_incremental(
&self,
&mut self,
chunk: &str,
state: &mut ParseState,
) -> ToolParserResult<StreamResult> {
// --- removed implementation: buffer lives in the caller-supplied state ---
state.buffer.push_str(chunk);
// Check for partial token at end of buffer
if let Some(_partial_len) = self.ends_with_partial_token(&state.buffer) {
// Hold back the partial token
return Ok(StreamResult::Incomplete);
}
// Check if we have the start marker
if !self.has_tool_markers(&state.buffer) {
// No tool markers detected - return all buffered content as normal text
let normal_text = std::mem::take(&mut state.buffer);
return Ok(StreamResult::NormalText(normal_text));
}
// Check for text before tool markers and extract it as normal text
if let Some(marker_pos) = state.buffer.find("<tool_call>") {
if marker_pos > 0 {
// We have text before the tool marker - extract it as normal text
let normal_text: String = state.buffer.drain(..marker_pos).collect();
return Ok(StreamResult::NormalText(normal_text));
// --- added implementation: signature tail and body ---
tools: &[Tool],
) -> ToolParserResult<StreamingParseResult> {
// Append new text to buffer
self.buffer.push_str(chunk);
// NOTE(review): clones the whole buffer and borrows the temporary; binding
// the clone directly (without the `&`) would read more clearly.
let current_text = &self.buffer.clone();
// Check if current_text has tool_call
let has_tool_start = self.has_tool_call(current_text)
|| (self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator));
if !has_tool_start {
// Only clear buffer if we're sure no tool call is starting
if helpers::ends_with_partial_token(&self.buffer, self.bot_token).is_none() {
let normal_text = self.buffer.clone();
self.buffer.clear();
return Ok(StreamingParseResult {
normal_text,
calls: vec![],
});
} else {
// Might be partial bot_token, keep buffering
return Ok(StreamingParseResult::default());
}
}
// --- removed implementation continues: complete-call fast path ---
// Find start and end positions
if let Some(start_pos) = self.find_tool_start(&state.buffer) {
// Check if we have the complete tool call
if let Some(end_pos) = self.find_tool_end(&state.buffer, start_pos) {
// Extract the JSON content
let json_start = start_pos + "<tool_call>\n".len();
let json_end = end_pos - "\n</tool_call>".len();
let json_str = &state.buffer[json_start..json_end];
// Parse the complete JSON
match serde_json::from_str::<Value>(json_str.trim()) {
Ok(value) => {
if let Some(tool) = self.parse_single_object(&value, 0)? {
// Clear the consumed part from buffer using drain for efficiency
state.buffer.drain(..end_pos);
return Ok(StreamResult::ToolComplete(tool));
}
}
Err(_) => {
// JSON parsing failed, might be incomplete or malformed
// If we have what looks like a complete tool call block, treat as normal text
if state.buffer[start_pos..end_pos].contains("\n</tool_call>") {
let malformed_text: String = state.buffer.drain(..end_pos).collect();
return Ok(StreamResult::NormalText(malformed_text));
}
}
}
// --- added implementation continues: delegate to the shared JSON streaming helper ---
// Build tool indices
let tool_indices = helpers::get_tool_indices(tools);
// Determine start index for JSON parsing
let start_idx = if let Some(pos) = current_text.find(self.bot_token) {
pos + self.bot_token.len()
} else if self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator) {
self.tool_call_separator.len()
} else {
0
};
let mut result = helpers::handle_json_tool_streaming(
current_text,
start_idx,
&mut self.partial_json,
&tool_indices,
&mut self.buffer,
&mut self.current_tool_id,
&mut self.current_tool_name_sent,
&mut self.streamed_args_for_tool,
&mut self.prev_tool_call_arr,
)?;
// Qwen-specific: Handle partial end tokens in normal text
// After tool calls complete, normal text might contain partial "</tool_call>" tags
if !result.normal_text.is_empty() {
self.normal_text_buffer.push_str(&result.normal_text);
// Check if buffer contains complete end token (without leading newline)
let end_token_without_newline = &self.eot_token[1..]; // "</tool_call>"
if self.normal_text_buffer.contains(end_token_without_newline) {
// Complete end token found - clean it and return
let cleaned_text = self
.normal_text_buffer
.replace(end_token_without_newline, "");
self.normal_text_buffer.clear();
result.normal_text = cleaned_text;
} else {
// --- removed implementation: partial-JSON path for incomplete tool calls ---
// We have start but no end yet - try partial parsing
let json_start = start_pos + "<tool_call>\n".len();
let partial_json = &state.buffer[json_start..];
// Remove trailing newline if present (might be start of end token)
let partial_json = partial_json.trim_end();
// Try to parse with partial JSON parser
match self.partial_json.parse_value(partial_json) {
Ok((value, _consumed)) => {
// Extract tool name if available
if let Some(name) = value.get("name").and_then(|v| v.as_str()) {
// Check if we've already sent the name
if !state.in_string {
state.in_string = true; // Use as flag for "name sent"
return Ok(StreamResult::ToolName {
index: 0,
name: name.to_string(),
});
}
// Check for arguments
if let Some(args) = value.get("arguments") {
if let Ok(args_str) = serde_json::to_string(args) {
return Ok(StreamResult::ToolArguments {
index: 0,
arguments: args_str,
});
}
}
}
}
Err(_) => {
// Failed to parse even as partial JSON
// Keep buffering
}
// --- added implementation: hold back a possible partial end tag ---
// Check if buffer might contain partial end token at the end
if let Some(partial_match_len) = helpers::ends_with_partial_token(
&self.normal_text_buffer,
end_token_without_newline,
) {
// Keep potential partial match in buffer, return the rest
let split_point = self.normal_text_buffer.len() - partial_match_len;
result.normal_text = self.normal_text_buffer[..split_point].to_string();
self.normal_text_buffer = self.normal_text_buffer[split_point..].to_string();
} else {
// No partial match, return all buffered text
result.normal_text = self.normal_text_buffer.clone();
self.normal_text_buffer.clear();
}
}
}
// Final lines: the removed version fell through to `Incomplete`; the added
// version returns the accumulated streaming result.
Ok(StreamResult::Incomplete)
Ok(result)
}
fn detect_format(&self, text: &str) -> bool {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment