Unverified commit 816c4c85 authored by Chang Su, committed by GitHub

[router] add tool parser base structure and partial json parser (#9482)

parent 13ec8d42
@@ -48,6 +48,7 @@ metrics = "0.24.2"
metrics-exporter-prometheus = "0.17.0"
uuid = { version = "1.10", features = ["v4", "serde"] }
thiserror = "2.0.12"
+ regex = "1.10"
url = "2.5.4"
tokio-stream = { version = "0.1", features = ["sync"] }
anyhow = "1.0"
@@ -100,7 +100,8 @@ fn bench_encode_throughput(c: &mut Criterion) {
let tokenizer_clone = tokenizer.clone();
// Get token count once
- let token_count = tokenizer.encode(prompt).unwrap().token_ids().len();
+ let encoding = tokenizer.encode(prompt).unwrap();
+ let token_count = encoding.token_ids().len();
// Track if metrics have been printed for this test case
let printed = Arc::new(AtomicBool::new(false));
@@ -157,7 +158,8 @@ fn bench_batch_encode(c: &mut Criterion) {
let batch_sizes = vec![1, 8, 16, 32, 64, 128];
let prompt = MEDIUM_PROMPT;
let prompt_len = prompt.len();
- let token_count = tokenizer.encode(prompt).unwrap().token_ids().len();
+ let encoding = tokenizer.encode(prompt).unwrap();
+ let token_count = encoding.token_ids().len();
let mut group = c.benchmark_group("batch_encode");
@@ -303,7 +305,8 @@ fn bench_decode_performance(c: &mut Criterion) {
);
let test_text = "The quick brown fox jumps over the lazy dog. ".repeat(10);
- let tokens = tokenizer.encode(&test_text).unwrap().token_ids();
+ let encoding = tokenizer.encode(&test_text).unwrap();
+ let tokens = encoding.token_ids();
let num_tokens = tokens.len();
let mut group = c.benchmark_group("decode_performance");
@@ -313,12 +316,11 @@ fn bench_decode_performance(c: &mut Criterion) {
group.bench_function("direct_decode", |b| {
let printed = printed_direct.clone();
let tokenizer = tokenizer.clone();
- let tokens = tokens.clone();
b.iter_custom(|iters| {
let start = Instant::now();
for _ in 0..iters {
- black_box(tokenizer.decode(&tokens, false).unwrap());
+ black_box(tokenizer.decode(tokens, false).unwrap());
}
let duration = start.elapsed();
@@ -344,14 +346,13 @@ fn bench_decode_performance(c: &mut Criterion) {
group.bench_function("decode_stream", |b| {
let printed = printed_stream.clone();
let tokenizer = tokenizer.clone();
- let tokens = tokens.clone();
b.iter_custom(|iters| {
let start = Instant::now();
for _ in 0..iters {
let mut decoder = DecodeStream::new(tokenizer.clone(), &[], false);
let mut output = String::new();
- for token in &tokens {
+ for token in tokens {
if let Some(text) = decoder.step(*token).unwrap() {
output.push_str(&text);
}
@@ -382,14 +383,13 @@ fn bench_decode_performance(c: &mut Criterion) {
group.bench_function("sequence_decode", |b| {
let printed = printed_seq.clone();
let tokenizer = tokenizer.clone();
- let tokens = tokens.clone();
b.iter_custom(|iters| {
let start = Instant::now();
for _ in 0..iters {
let mut sequence = Sequence::new(tokenizer.clone());
let mut output = String::new();
- for token in &tokens {
+ for token in tokens {
let text = sequence.append_token(*token).unwrap();
output.push_str(&text);
}
@@ -424,7 +424,8 @@ fn bench_streaming_decode_100k(c: &mut Criterion) {
);
let sample_text = "The quick brown fox jumps over the lazy dog. ".repeat(1000);
- let all_tokens = tokenizer.encode(&sample_text).unwrap().token_ids();
+ let encoding = tokenizer.encode(&sample_text).unwrap();
+ let all_tokens = encoding.token_ids();
let mut group = c.benchmark_group("streaming_100k");
group.measurement_time(Duration::from_secs(1));
@@ -434,7 +435,6 @@ fn bench_streaming_decode_100k(c: &mut Criterion) {
group.bench_function("decode_stream_100k", |b| {
let printed = printed_stream.clone();
let tokenizer = tokenizer.clone();
- let tokens = all_tokens.clone();
b.iter_custom(|_iters| {
let start = Instant::now();
@@ -442,7 +442,7 @@ fn bench_streaming_decode_100k(c: &mut Criterion) {
let mut output = String::new();
let mut tokens_processed = 0u64;
- for token in tokens.iter().cycle() {
+ for token in all_tokens.iter().cycle() {
if start.elapsed() >= Duration::from_millis(500) {
break;
}
@@ -486,7 +486,6 @@ fn bench_streaming_decode_100k(c: &mut Criterion) {
group.bench_function("sequence_100k", |b| {
let printed = printed_seq.clone();
let tokenizer = tokenizer.clone();
- let tokens = all_tokens.clone();
b.iter_custom(|_iters| {
let start = Instant::now();
@@ -494,7 +493,7 @@ fn bench_streaming_decode_100k(c: &mut Criterion) {
let mut output = String::new();
let mut tokens_processed = 0u64;
- for token in tokens.iter().cycle() {
+ for token in all_tokens.iter().cycle() {
if start.elapsed() >= Duration::from_millis(500) {
break;
}
@@ -693,7 +692,8 @@ fn bench_concurrent_streaming(c: &mut Criterion) {
let tokens_per_sequence = 10_000;
let sample_text = "The quick brown fox jumps over the lazy dog. ".repeat(100);
- let token_batch = tokenizer.encode(&sample_text).unwrap().token_ids();
+ let encoding = tokenizer.encode(&sample_text).unwrap();
+ let token_batch: Vec<u32> = encoding.token_ids().to_vec();
let mut group = c.benchmark_group("concurrent_streaming");
group.measurement_time(Duration::from_secs(2));
@@ -775,7 +775,8 @@ fn bench_stop_sequences(c: &mut Criterion) {
.with_stop_token(2);
let sample_text = "Hello world! This is a test. ### Stop here. Continue after.".repeat(100);
- let tokens = tokenizer.encode(&sample_text).unwrap().token_ids();
+ let encoding = tokenizer.encode(&sample_text).unwrap();
+ let tokens = encoding.token_ids();
let mut group = c.benchmark_group("stop_sequences");
@@ -784,7 +785,6 @@ fn bench_stop_sequences(c: &mut Criterion) {
group.bench_function("no_stops", |b| {
let printed_clone = printed_no_stop.clone();
let tokenizer = tokenizer.clone();
- let tokens = tokens.clone();
b.iter_custom(|iters| {
let start = Instant::now();
@@ -796,7 +796,7 @@ fn bench_stop_sequences(c: &mut Criterion) {
StopSequenceConfig::default(),
false,
);
- for token in &tokens {
+ for token in tokens {
let _ = decoder.process_token(*token).unwrap();
total_tokens += 1;
}
@@ -826,7 +826,6 @@ fn bench_stop_sequences(c: &mut Criterion) {
group.bench_function("with_stops", |b| {
let printed_clone = printed_with_stops.clone();
let tokenizer = tokenizer.clone();
- let tokens = tokens.clone();
let config = config.clone();
b.iter_custom(|iters| {
@@ -839,7 +838,7 @@ fn bench_stop_sequences(c: &mut Criterion) {
StopSequenceDecoder::new(tokenizer.clone(), config.clone(), false);
let mut sequence_tokens = 0u64;
- for token in &tokens {
+ for token in tokens {
let result = decoder.process_token(*token).unwrap();
sequence_tokens += 1;
@@ -986,7 +985,8 @@ fn bench_multithreaded_decode(c: &mut Criterion) {
// Generate tokens for decoding
let test_text = "The quick brown fox jumps over the lazy dog. ".repeat(100);
- let test_tokens = tokenizer.encode(&test_text).unwrap().token_ids();
+ let encoding = tokenizer.encode(&test_text).unwrap();
+ let test_tokens: Vec<u32> = encoding.token_ids().to_vec();
let mut group = c.benchmark_group("multithreaded_decode");
group.measurement_time(Duration::from_secs(2));
@@ -1130,7 +1130,7 @@ fn bench_memory_efficiency(c: &mut Criterion) {
b.iter_custom(|iters| {
let start = Instant::now();
for _ in 0..iters {
- let _ = black_box(encoding.token_ids_ref());
+ let _ = black_box(encoding.token_ids());
}
let duration = start.elapsed();
@@ -14,6 +14,7 @@ pub mod routers;
pub mod server;
pub mod service_discovery;
pub mod tokenizer;
+ pub mod tool_parser;
pub mod tree;
use crate::metrics::PrometheusConfig;
// tool_parser/errors.rs (new file)
use thiserror::Error;
/// Result type for tool parser operations
pub type ToolParserResult<T> = Result<T, ToolParserError>;
/// Errors that can occur during tool parsing
#[derive(Debug, Error)]
pub enum ToolParserError {
#[error("Parsing failed: {0}")]
ParsingFailed(String),
#[error("Model not supported: {0}")]
ModelNotSupported(String),
#[error("Parse depth exceeded: max {0}")]
DepthExceeded(usize),
#[error("Invalid JSON: {0}")]
JsonError(#[from] serde_json::Error),
#[error("Regex error: {0}")]
RegexError(#[from] regex::Error),
#[error("Incomplete tool call")]
Incomplete,
#[error("Invalid tool name: {0}")]
InvalidToolName(String),
#[error("Token not found: {0}")]
TokenNotFound(String),
}
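// Illustrative sketch, not part of this commit: because JsonError and
// RegexError derive #[from], fallible serde_json / regex calls inside a
// parser can be lifted into ToolParserResult with the `?` operator.
// `parse_arguments` is a hypothetical helper shown only for the error flow.
fn parse_arguments(raw: &str) -> ToolParserResult<serde_json::Value> {
    // A serde_json::Error here converts automatically into ToolParserError::JsonError.
    let value: serde_json::Value = serde_json::from_str(raw)?;
    Ok(value)
}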
// tool_parser/mod.rs (new file)
//! Tool parser module for handling function/tool calls in model outputs
//!
//! This module provides infrastructure for parsing tool calls from various model formats.
//! Phase 1 focuses on core infrastructure: types, traits, registry, and partial JSON parsing.
pub mod errors;
pub mod partial_json;
pub mod registry;
pub mod state;
pub mod traits;
pub mod types;
#[cfg(test)]
mod tests;
// Re-export commonly used types
pub use errors::{ToolParserError, ToolParserResult};
pub use registry::ParserRegistry;
pub use state::{ParsePhase, ParseState};
pub use traits::{PartialJsonParser, ToolParser};
pub use types::{FunctionCall, PartialToolCall, StreamResult, TokenConfig, ToolCall};
// tool_parser/partial_json.rs (new file)
use crate::tool_parser::{
errors::{ToolParserError, ToolParserResult},
traits::PartialJsonParser,
};
use serde_json::{Map, Value};
/// Parser for incomplete JSON
pub struct PartialJson {
/// Maximum depth for nested structures
max_depth: usize,
/// Whether to allow incomplete values
allow_incomplete: bool,
}
impl PartialJson {
/// Create a new partial JSON parser
pub fn new(max_depth: usize, allow_incomplete: bool) -> Self {
Self {
max_depth,
allow_incomplete,
}
}
/// Parse potentially incomplete JSON, returning the parsed value and the number of characters consumed
pub fn parse_value(&self, input: &str) -> ToolParserResult<(Value, usize)> {
let mut parser = Parser::new(input, self.max_depth, self.allow_incomplete);
let value = parser.parse_value(0)?;
Ok((value, parser.position))
}
}
impl Default for PartialJson {
fn default() -> Self {
Self::new(32, true)
}
}
impl PartialJsonParser for PartialJson {
fn parse(&self, input: &str) -> ToolParserResult<(Value, usize)> {
self.parse_value(input)
}
fn is_complete(&self, input: &str) -> bool {
// Try to parse as complete JSON
serde_json::from_str::<Value>(input).is_ok()
}
fn max_depth(&self) -> usize {
self.max_depth
}
}
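// Illustrative usage, not part of this commit: parsing a truncated streaming
// chunk with the default parser (max_depth = 32, allow_incomplete = true).
// Unterminated strings are returned as-is and missing values become null.
fn demo_partial_parse() -> ToolParserResult<()> {
    let input = r#"{"name": "get_weather", "arguments": {"city": "Par"#;
    let parser = PartialJson::default();
    let (value, consumed) = parser.parse_value(input)?;
    assert_eq!(value["name"], "get_weather");
    assert_eq!(value["arguments"]["city"], "Par");
    // `consumed` counts characters, which equals bytes for this ASCII input.
    assert_eq!(consumed, input.chars().count());
    Ok(())
}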
/// Internal parser state
struct Parser<'a> {
chars: std::iter::Peekable<std::str::Chars<'a>>,
position: usize,
max_depth: usize,
allow_incomplete: bool,
}
impl<'a> Parser<'a> {
fn new(input: &'a str, max_depth: usize, allow_incomplete: bool) -> Self {
Self {
chars: input.chars().peekable(),
position: 0,
max_depth,
allow_incomplete,
}
}
fn peek(&mut self) -> Option<char> {
self.chars.peek().copied()
}
fn advance(&mut self) {
if self.chars.next().is_some() {
self.position += 1;
}
}
fn skip_whitespace(&mut self) {
while let Some(ch) = self.peek() {
if ch.is_whitespace() {
self.advance();
} else {
break;
}
}
}
fn parse_value(&mut self, depth: usize) -> ToolParserResult<Value> {
if depth > self.max_depth {
return Err(ToolParserError::DepthExceeded(self.max_depth));
}
self.skip_whitespace();
match self.peek() {
Some('{') => self.parse_object(depth + 1),
Some('[') => self.parse_array(depth + 1),
Some('"') => self.parse_string(),
Some('t') | Some('f') => self.parse_bool(),
Some('n') => self.parse_null(),
Some(c) if c == '-' || c.is_ascii_digit() => self.parse_number(),
_ => {
if self.allow_incomplete {
Ok(Value::Null)
} else {
Err(ToolParserError::ParsingFailed(
"Unexpected character".into(),
))
}
}
}
}
fn parse_object(&mut self, depth: usize) -> ToolParserResult<Value> {
if depth > self.max_depth {
return Err(ToolParserError::DepthExceeded(self.max_depth));
}
let mut object = Map::new();
// Consume '{'
self.advance();
self.skip_whitespace();
// Check for empty object
if self.peek() == Some('}') {
self.advance();
return Ok(Value::Object(object));
}
loop {
// Parse key
let key = match self.parse_string() {
Ok(Value::String(s)) => s,
Err(_) if self.allow_incomplete => {
// Incomplete object
return Ok(Value::Object(object));
}
Err(e) => return Err(e),
_ => return Err(ToolParserError::ParsingFailed("Expected string key".into())),
};
self.skip_whitespace();
// Expect ':'
if self.peek() != Some(':') {
if self.allow_incomplete {
// Add null value for incomplete pair
object.insert(key, Value::Null);
return Ok(Value::Object(object));
}
return Err(ToolParserError::ParsingFailed("Expected ':'".into()));
}
self.advance();
self.skip_whitespace();
// Parse value (keep same depth - we already incremented in parse_object)
let value = match self.parse_value(depth) {
Ok(v) => v,
Err(_) if self.allow_incomplete => {
// Add null for incomplete value
object.insert(key, Value::Null);
return Ok(Value::Object(object));
}
Err(e) => return Err(e),
};
object.insert(key, value);
self.skip_whitespace();
match self.peek() {
Some(',') => {
self.advance();
self.skip_whitespace();
// Check for trailing comma
if self.peek() == Some('}') {
self.advance();
return Ok(Value::Object(object));
}
}
Some('}') => {
self.advance();
return Ok(Value::Object(object));
}
None if self.allow_incomplete => {
return Ok(Value::Object(object));
}
_ => {
if self.allow_incomplete {
return Ok(Value::Object(object));
}
return Err(ToolParserError::ParsingFailed("Expected ',' or '}'".into()));
}
}
}
}
fn parse_array(&mut self, depth: usize) -> ToolParserResult<Value> {
if depth > self.max_depth {
return Err(ToolParserError::DepthExceeded(self.max_depth));
}
let mut array = Vec::new();
// Consume '['
self.advance();
self.skip_whitespace();
// Check for empty array
if self.peek() == Some(']') {
self.advance();
return Ok(Value::Array(array));
}
loop {
// Parse value (keep same depth - we already incremented in parse_array)
let value = match self.parse_value(depth) {
Ok(v) => v,
Err(_) if self.allow_incomplete => {
return Ok(Value::Array(array));
}
Err(e) => return Err(e),
};
array.push(value);
self.skip_whitespace();
match self.peek() {
Some(',') => {
self.advance();
self.skip_whitespace();
// Check for trailing comma
if self.peek() == Some(']') {
self.advance();
return Ok(Value::Array(array));
}
}
Some(']') => {
self.advance();
return Ok(Value::Array(array));
}
None if self.allow_incomplete => {
return Ok(Value::Array(array));
}
_ => {
if self.allow_incomplete {
return Ok(Value::Array(array));
}
return Err(ToolParserError::ParsingFailed("Expected ',' or ']'".into()));
}
}
}
}
fn parse_string(&mut self) -> ToolParserResult<Value> {
if self.peek() != Some('"') {
return Err(ToolParserError::ParsingFailed("Expected '\"'".into()));
}
// Consume opening quote
self.advance();
let mut string = String::new();
let mut escaped = false;
while let Some(ch) = self.peek() {
if escaped {
// Handle escape sequences
let escaped_char = match ch {
'"' | '\\' | '/' => ch,
'b' => '\u{0008}',
'f' => '\u{000C}',
'n' => '\n',
'r' => '\r',
't' => '\t',
'u' => {
// Unicode escape
self.advance();
let hex = self.parse_unicode_escape()?;
string.push(hex);
escaped = false;
continue;
}
_ => ch, // Invalid escape, but be lenient
};
string.push(escaped_char);
escaped = false;
} else if ch == '\\' {
escaped = true;
} else if ch == '"' {
// End of string
self.advance();
return Ok(Value::String(string));
} else {
string.push(ch);
}
self.advance();
}
// Incomplete string
if self.allow_incomplete {
Ok(Value::String(string))
} else {
Err(ToolParserError::ParsingFailed("Unterminated string".into()))
}
}
fn parse_unicode_escape(&mut self) -> ToolParserResult<char> {
let mut hex = String::new();
for _ in 0..4 {
if let Some(ch) = self.peek() {
if ch.is_ascii_hexdigit() {
hex.push(ch);
self.advance();
} else {
break;
}
} else {
break;
}
}
if hex.len() == 4 {
u32::from_str_radix(&hex, 16)
.ok()
.and_then(char::from_u32)
.ok_or_else(|| ToolParserError::ParsingFailed("Invalid unicode escape".into()))
} else if self.allow_incomplete {
Ok('\u{FFFD}') // Replacement character
} else {
Err(ToolParserError::ParsingFailed(
"Incomplete unicode escape".into(),
))
}
}
fn parse_number(&mut self) -> ToolParserResult<Value> {
let mut number = String::new();
// Handle negative sign
if self.peek() == Some('-') {
number.push('-');
self.advance();
}
// Parse integer part
if self.peek() == Some('0') {
number.push('0');
self.advance();
} else {
while let Some(ch) = self.peek() {
if ch.is_ascii_digit() {
number.push(ch);
self.advance();
} else {
break;
}
}
}
// Parse decimal part
if self.peek() == Some('.') {
number.push('.');
self.advance();
while let Some(ch) = self.peek() {
if ch.is_ascii_digit() {
number.push(ch);
self.advance();
} else {
break;
}
}
}
// Parse exponent
if let Some(ch) = self.peek() {
if ch == 'e' || ch == 'E' {
number.push(ch);
self.advance();
if let Some(sign) = self.peek() {
if sign == '+' || sign == '-' {
number.push(sign);
self.advance();
}
}
while let Some(ch) = self.peek() {
if ch.is_ascii_digit() {
number.push(ch);
self.advance();
} else {
break;
}
}
}
}
// Try to parse as integer first, then as float
if let Ok(n) = number.parse::<i64>() {
Ok(Value::Number(serde_json::Number::from(n)))
} else if let Ok(n) = number.parse::<f64>() {
Ok(Value::Number(
serde_json::Number::from_f64(n).unwrap_or_else(|| serde_json::Number::from(0)),
))
} else if self.allow_incomplete {
Ok(Value::Number(serde_json::Number::from(0)))
} else {
Err(ToolParserError::ParsingFailed("Invalid number".into()))
}
}
fn parse_bool(&mut self) -> ToolParserResult<Value> {
let mut word = String::new();
// Peek at upcoming characters to validate it looks like a boolean
let mut temp_chars = self.chars.clone();
while let Some(&ch) = temp_chars.peek() {
if ch.is_alphabetic() && word.len() < 5 {
// "false" is 5 chars
word.push(ch);
temp_chars.next();
} else {
break;
}
}
// Check if it's a valid boolean prefix
let is_valid = word == "true"
|| word == "false"
|| (self.allow_incomplete && ("true".starts_with(&word) || "false".starts_with(&word)));
if !is_valid {
return Err(ToolParserError::ParsingFailed("Invalid boolean".into()));
}
// Now actually consume the characters
word.clear();
while let Some(ch) = self.peek() {
if ch.is_alphabetic() {
word.push(ch);
self.advance();
} else {
break;
}
}
match word.as_str() {
"true" => Ok(Value::Bool(true)),
"false" => Ok(Value::Bool(false)),
partial if self.allow_incomplete => {
if "true".starts_with(partial) {
Ok(Value::Bool(true))
} else if "false".starts_with(partial) {
Ok(Value::Bool(false))
} else {
Err(ToolParserError::ParsingFailed("Invalid boolean".into()))
}
}
_ => Err(ToolParserError::ParsingFailed("Invalid boolean".into())),
}
}
fn parse_null(&mut self) -> ToolParserResult<Value> {
let mut word = String::new();
// Peek at upcoming characters to validate it looks like "null"
let mut temp_chars = self.chars.clone();
while let Some(&ch) = temp_chars.peek() {
if ch.is_alphabetic() && word.len() < 4 {
// "null" is 4 chars
word.push(ch);
temp_chars.next();
} else {
break;
}
}
// Check if it's a valid null prefix
let is_valid = word == "null" || (self.allow_incomplete && "null".starts_with(&word));
if !is_valid {
return Err(ToolParserError::ParsingFailed("Invalid null".into()));
}
// Now actually consume the characters
word.clear();
while let Some(ch) = self.peek() {
if ch.is_alphabetic() {
word.push(ch);
self.advance();
} else {
break;
}
}
if word == "null" || (self.allow_incomplete && "null".starts_with(&word)) {
Ok(Value::Null)
} else {
Err(ToolParserError::ParsingFailed("Invalid null".into()))
}
}
}
/// Utility function to check if a string contains complete JSON
pub fn is_complete_json(input: &str) -> bool {
serde_json::from_str::<Value>(input).is_ok()
}
/// Utility function to find common prefix between two strings
pub fn find_common_prefix(s1: &str, s2: &str) -> usize {
s1.chars()
.zip(s2.chars())
.take_while(|(a, b)| a == b)
.count()
}
/// Utility function to compute diff between old and new strings
pub fn compute_diff(old: &str, new: &str) -> String {
let common_len = find_common_prefix(old, new);
// Skip the common prefix (a character count) and return the remaining suffix
new.chars().skip(common_len).collect()
}
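// Illustrative sketch, not part of this commit: a streaming parser can call
// compute_diff on successive argument snapshots to emit only the new suffix.
fn demo_compute_diff() {
    let previous = r#"{"query": "rust"#;
    let current = r#"{"query": "rust programming"}"#;
    assert_eq!(find_common_prefix(previous, current), previous.chars().count());
    assert_eq!(compute_diff(previous, current), r#" programming"}"#);
}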
// tool_parser/registry.rs (new file)
use crate::tool_parser::traits::ToolParser;
use std::collections::HashMap;
use std::sync::Arc;
/// Registry for tool parsers and model mappings
pub struct ParserRegistry {
/// Map of parser name to parser instance
parsers: HashMap<String, Arc<dyn ToolParser>>,
/// Map of model name/pattern to parser name
model_mapping: HashMap<String, String>,
/// Default parser to use when no match found
default_parser: String,
}
impl ParserRegistry {
/// Create a new parser registry with default mappings
pub fn new() -> Self {
let mut registry = Self {
parsers: HashMap::new(),
model_mapping: HashMap::new(),
default_parser: "json".to_string(),
};
// Register default model mappings
registry.register_default_mappings();
registry
}
/// Register a parser
pub fn register_parser(&mut self, name: impl Into<String>, parser: Arc<dyn ToolParser>) {
self.parsers.insert(name.into(), parser);
}
/// Map a model name/pattern to a parser
pub fn map_model(&mut self, model: impl Into<String>, parser: impl Into<String>) {
self.model_mapping.insert(model.into(), parser.into());
}
/// Get parser for a specific model
pub fn get_parser(&self, model: &str) -> Option<Arc<dyn ToolParser>> {
// Try exact match first
if let Some(parser_name) = self.model_mapping.get(model) {
if let Some(parser) = self.parsers.get(parser_name) {
return Some(parser.clone());
}
}
// Try prefix matching (e.g., "gpt-4" matches "gpt-*")
for (pattern, parser_name) in &self.model_mapping {
if pattern.ends_with('*') {
let prefix = &pattern[..pattern.len() - 1];
if model.starts_with(prefix) {
if let Some(parser) = self.parsers.get(parser_name) {
return Some(parser.clone());
}
}
}
}
// Fall back to default parser if it exists
self.parsers.get(&self.default_parser).cloned()
}
/// List all registered parsers
pub fn list_parsers(&self) -> Vec<&str> {
self.parsers.keys().map(|s| s.as_str()).collect()
}
/// List all model mappings
pub fn list_mappings(&self) -> Vec<(&str, &str)> {
self.model_mapping
.iter()
.map(|(k, v)| (k.as_str(), v.as_str()))
.collect()
}
/// Register default model mappings
fn register_default_mappings(&mut self) {
// OpenAI models
self.map_model("gpt-4*", "json");
self.map_model("gpt-3.5*", "json");
self.map_model("gpt-4o*", "json");
// Anthropic models
self.map_model("claude-*", "json");
// Mistral models
self.map_model("mistral-*", "mistral");
self.map_model("mixtral-*", "mistral");
// Qwen models
self.map_model("qwen*", "qwen");
// Llama models
self.map_model("llama-*", "llama");
self.map_model("meta-llama-*", "llama");
// Other models default to JSON
self.map_model("gemini-*", "json");
self.map_model("palm-*", "json");
}
/// Set the default parser
pub fn set_default_parser(&mut self, name: impl Into<String>) {
self.default_parser = name.into();
}
/// Check if a parser is registered
pub fn has_parser(&self, name: &str) -> bool {
self.parsers.contains_key(name)
}
}
impl Default for ParserRegistry {
fn default() -> Self {
Self::new()
}
}
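// Illustrative sketch, not part of this commit: mappings can be inspected and
// extended, but lookups only succeed once parser instances are registered.
// Phase 1 ships the default model mappings without concrete parsers, so
// get_parser falls through to the "json" default and finds nothing.
fn demo_registry() {
    let mut registry = ParserRegistry::new();
    registry.map_model("my-org/custom-model", "json"); // hypothetical model name
    assert!(registry
        .list_mappings()
        .iter()
        .any(|(model, parser)| *model == "my-org/custom-model" && *parser == "json"));
    assert!(!registry.has_parser("json"));
    assert!(registry.get_parser("qwen2.5-7b").is_none());
}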
// tool_parser/state.rs (new file)
use crate::tool_parser::types::{PartialToolCall, ToolCall};
/// Current phase of parsing
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParsePhase {
/// Looking for start of tool call
Searching,
/// Parsing function name
InName,
/// Parsing function arguments
InArguments,
/// Tool call complete
Complete,
}
/// State for streaming parser
#[derive(Debug, Clone)]
pub struct ParseState {
/// Buffer for accumulating input
pub buffer: String,
/// Position of last consumed character
pub consumed: usize,
/// Current partial tool being parsed
pub partial_tool: Option<PartialToolCall>,
/// Completed tool calls
pub completed_tools: Vec<ToolCall>,
/// Current parsing phase
pub phase: ParsePhase,
/// Bracket/brace depth for JSON parsing
pub bracket_depth: i32,
/// Whether currently inside a string literal
pub in_string: bool,
/// Whether next character should be escaped
pub escape_next: bool,
/// Current tool index (for streaming)
pub tool_index: usize,
}
impl ParseState {
/// Create a new parse state
pub fn new() -> Self {
Self {
buffer: String::new(),
consumed: 0,
partial_tool: None,
completed_tools: Vec::new(),
phase: ParsePhase::Searching,
bracket_depth: 0,
in_string: false,
escape_next: false,
tool_index: 0,
}
}
/// Reset state for parsing next tool
pub fn reset(&mut self) {
self.partial_tool = None;
self.phase = ParsePhase::Searching;
self.bracket_depth = 0;
self.in_string = false;
self.escape_next = false;
}
/// Process a single character for JSON parsing
pub fn process_char(&mut self, ch: char) {
// Handle escape sequences
if self.escape_next {
self.escape_next = false;
self.buffer.push(ch);
return;
}
if ch == '\\' && self.in_string {
self.escape_next = true;
self.buffer.push(ch);
return;
}
// Track string boundaries
if ch == '"' && !self.escape_next {
self.in_string = !self.in_string;
}
// Track bracket depth for JSON
if !self.in_string {
match ch {
'{' | '[' => {
self.bracket_depth += 1;
}
'}' | ']' => {
self.bracket_depth -= 1;
if self.bracket_depth == 0 && self.partial_tool.is_some() {
// Complete tool call found
self.phase = ParsePhase::Complete;
}
}
_ => {}
}
}
self.buffer.push(ch);
}
/// Check if we have a complete JSON object/array
pub fn has_complete_json(&self) -> bool {
self.bracket_depth == 0 && !self.in_string && !self.buffer.is_empty()
}
/// Extract content from buffer starting at position
pub fn extract_from(&self, start: usize) -> &str {
if start >= self.buffer.len() {
return "";
}
// Find the nearest character boundary at or after start
let mut safe_start = start;
while safe_start < self.buffer.len() && !self.buffer.is_char_boundary(safe_start) {
safe_start += 1;
}
if safe_start < self.buffer.len() {
&self.buffer[safe_start..]
} else {
""
}
}
/// Mark content as consumed up to position
pub fn consume_to(&mut self, position: usize) {
if position > self.consumed {
self.consumed = position;
}
}
/// Get unconsumed content
pub fn unconsumed(&self) -> &str {
if self.consumed >= self.buffer.len() {
return "";
}
// Find the nearest character boundary at or after consumed
let mut safe_consumed = self.consumed;
while safe_consumed < self.buffer.len() && !self.buffer.is_char_boundary(safe_consumed) {
safe_consumed += 1;
}
if safe_consumed < self.buffer.len() {
&self.buffer[safe_consumed..]
} else {
""
}
}
/// Clear consumed content from buffer
pub fn clear_consumed(&mut self) {
if self.consumed > 0 {
// Find the nearest character boundary at or before consumed
let mut safe_consumed = self.consumed;
while safe_consumed > 0 && !self.buffer.is_char_boundary(safe_consumed) {
safe_consumed -= 1;
}
if safe_consumed > 0 {
self.buffer.drain(..safe_consumed);
self.consumed = self.consumed.saturating_sub(safe_consumed);
}
}
}
/// Add completed tool
pub fn add_completed_tool(&mut self, tool: ToolCall) {
self.completed_tools.push(tool);
self.tool_index += 1;
}
}
impl Default for ParseState {
fn default() -> Self {
Self::new()
}
}
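// Illustrative sketch, not part of this commit: driving ParseState one
// character at a time, the way a streaming tool parser would feed it chunks.
fn demo_parse_state() {
    let mut state = ParseState::new();
    for ch in r#"{"a": "b}"}"#.chars() {
        state.process_char(ch);
    }
    // The '}' inside the string literal did not close the object; bracket
    // depth balances only at the real closing brace.
    assert!(state.has_complete_json());
    assert_eq!(state.unconsumed(), r#"{"a": "b}"}"#);
}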
// tool_parser/tests.rs (new file)
use super::*;
use crate::tool_parser::partial_json::{
compute_diff, find_common_prefix, is_complete_json, PartialJson,
};
#[test]
fn test_parse_state_new() {
let state = ParseState::new();
assert_eq!(state.phase, ParsePhase::Searching);
assert_eq!(state.buffer, "");
assert_eq!(state.consumed, 0);
assert_eq!(state.bracket_depth, 0);
assert!(!state.in_string);
assert!(!state.escape_next);
}
#[test]
fn test_parse_state_process_char() {
let mut state = ParseState::new();
// Test bracket tracking
state.process_char('{');
assert_eq!(state.bracket_depth, 1);
state.process_char('}');
assert_eq!(state.bracket_depth, 0);
// Test string tracking
state.process_char('"');
assert!(state.in_string);
state.process_char('"');
assert!(!state.in_string);
// Test escape handling
state.process_char('"');
state.process_char('\\');
assert!(state.escape_next);
state.process_char('"');
assert!(!state.escape_next);
assert!(state.in_string); // Still in string because quote was escaped
}
#[test]
fn test_token_config() {
let config = TokenConfig {
start_tokens: vec!["<start>".to_string(), "[".to_string()],
end_tokens: vec!["</end>".to_string(), "]".to_string()],
separator: ", ".to_string(),
};
let pairs: Vec<_> = config.iter_pairs().collect();
assert_eq!(pairs.len(), 2);
assert_eq!(pairs[0], ("<start>", "</end>"));
assert_eq!(pairs[1], ("[", "]"));
}
#[test]
fn test_parser_registry() {
let registry = ParserRegistry::new();
// Test has default mappings
assert!(!registry.list_mappings().is_empty());
// Test model pattern matching
let mappings = registry.list_mappings();
let has_gpt = mappings.iter().any(|(m, _)| m.starts_with("gpt"));
assert!(has_gpt);
}
#[test]
fn test_parser_registry_pattern_matching() {
let mut registry = ParserRegistry::new();
// Test that model mappings work by checking the list
registry.map_model("test-model", "json");
// Verify through list_mappings
let mappings = registry.list_mappings();
let has_test = mappings
.iter()
.any(|(m, p)| *m == "test-model" && *p == "json");
assert!(has_test);
}
#[test]
fn test_tool_call_serialization() {
let tool_call = ToolCall {
id: "call-123".to_string(),
r#type: "function".to_string(),
function: FunctionCall {
name: "search".to_string(),
arguments: r#"{"query": "rust programming"}"#.to_string(),
},
};
let json = serde_json::to_string(&tool_call).unwrap();
assert!(json.contains("call-123"));
assert!(json.contains("search"));
assert!(json.contains("rust programming"));
let parsed: ToolCall = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.id, "call-123");
assert_eq!(parsed.function.name, "search");
}
#[test]
fn test_partial_json_parser() {
let parser = PartialJson::default();
// Test complete JSON
let input = r#"{"name": "test", "value": 42}"#;
let (value, consumed) = parser.parse_value(input).unwrap();
assert_eq!(value["name"], "test");
assert_eq!(value["value"], 42);
assert_eq!(consumed, input.len());
// Test incomplete JSON object
let input = r#"{"name": "test", "value": "#;
let (value, _consumed) = parser.parse_value(input).unwrap();
assert_eq!(value["name"], "test");
assert!(value["value"].is_null());
// Test incomplete string
let input = r#"{"name": "tes"#;
let (value, _consumed) = parser.parse_value(input).unwrap();
assert_eq!(value["name"], "tes");
// Test incomplete array
let input = r#"[1, 2, "#;
let (value, _consumed) = parser.parse_value(input).unwrap();
assert!(value.is_array());
assert_eq!(value[0], 1);
assert_eq!(value[1], 2);
}
#[test]
fn test_partial_json_depth_limit() {
// max_depth of 3 allows nesting up to 3 levels
// Set allow_incomplete to false to get errors instead of partial results
let parser = PartialJson::new(3, false);
// This should work (simple object)
let input = r#"{"a": 1}"#;
let result = parser.parse_value(input);
assert!(result.is_ok());
// This should work (nested to depth 3)
let input = r#"{"a": {"b": {"c": 1}}}"#;
let result = parser.parse_value(input);
assert!(result.is_ok());
// This should fail (nested to depth 4, exceeds limit)
let input = r#"{"a": {"b": {"c": {"d": 1}}}}"#;
let result = parser.parse_value(input);
assert!(result.is_err());
}
#[test]
fn test_is_complete_json() {
assert!(is_complete_json(r#"{"name": "test"}"#));
assert!(is_complete_json(r#"[1, 2, 3]"#));
assert!(is_complete_json(r#""string""#));
assert!(is_complete_json("42"));
assert!(is_complete_json("true"));
assert!(is_complete_json("null"));
assert!(!is_complete_json(r#"{"name": "#));
assert!(!is_complete_json(r#"[1, 2, "#));
assert!(!is_complete_json(r#""unclosed"#));
}
#[test]
fn test_find_common_prefix() {
assert_eq!(find_common_prefix("hello", "hello"), 5);
assert_eq!(find_common_prefix("hello", "help"), 3);
assert_eq!(find_common_prefix("hello", "world"), 0);
assert_eq!(find_common_prefix("", "hello"), 0);
assert_eq!(find_common_prefix("hello", ""), 0);
}
#[test]
fn test_compute_diff() {
assert_eq!(compute_diff("hello", "hello world"), " world");
assert_eq!(compute_diff("", "hello"), "hello");
assert_eq!(compute_diff("hello", "hello"), "");
assert_eq!(compute_diff("test", "hello"), "hello");
}
#[test]
fn test_stream_result_variants() {
// Test Incomplete
let result = StreamResult::Incomplete;
assert!(matches!(result, StreamResult::Incomplete));
// Test ToolName
let result = StreamResult::ToolName {
index: 0,
name: "test".to_string(),
};
if let StreamResult::ToolName { index, name } = result {
assert_eq!(index, 0);
assert_eq!(name, "test");
} else {
panic!("Expected ToolName variant");
}
// Test ToolComplete
let tool = ToolCall {
id: "123".to_string(),
r#type: "function".to_string(),
function: FunctionCall {
name: "test".to_string(),
arguments: "{}".to_string(),
},
};
let result = StreamResult::ToolComplete(tool.clone());
if let StreamResult::ToolComplete(t) = result {
assert_eq!(t.id, "123");
} else {
panic!("Expected ToolComplete variant");
}
}
#[test]
fn test_partial_tool_call() {
let mut partial = PartialToolCall {
name: None,
arguments_buffer: String::new(),
start_position: 0,
name_sent: false,
streamed_args: String::new(),
};
// Set name
partial.name = Some("test_function".to_string());
assert_eq!(partial.name.as_ref().unwrap(), "test_function");
// Append arguments
partial.arguments_buffer.push_str(r#"{"key": "value"}"#);
assert_eq!(partial.arguments_buffer, r#"{"key": "value"}"#);
// Update streaming state
partial.name_sent = true;
partial.streamed_args = r#"{"key": "#.to_string();
assert!(partial.name_sent);
assert_eq!(partial.streamed_args, r#"{"key": "#);
}
// tool_parser/traits.rs (new file)
use crate::tool_parser::{
errors::ToolParserResult,
state::ParseState,
types::{StreamResult, ToolCall},
};
use async_trait::async_trait;
/// Core trait for all tool parsers
#[async_trait]
pub trait ToolParser: Send + Sync {
/// Parse complete tool calls from final output
async fn parse_complete(&self, output: &str) -> ToolParserResult<Vec<ToolCall>>;
/// Parse tool calls from model output (streaming)
async fn parse_incremental(
&self,
chunk: &str,
state: &mut ParseState,
) -> ToolParserResult<StreamResult>;
/// Check if text contains tool calls in this parser's format
fn detect_format(&self, text: &str) -> bool;
}
/// Trait for partial JSON parsing
pub trait PartialJsonParser: Send + Sync {
/// Parse potentially incomplete JSON
fn parse(&self, input: &str) -> ToolParserResult<(serde_json::Value, usize)>;
/// Check if JSON is complete
fn is_complete(&self, input: &str) -> bool;
/// Get the maximum parsing depth
fn max_depth(&self) -> usize;
}
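// Illustrative sketch, not part of this commit: the smallest possible
// ToolParser impl, accepting output that is exactly one JSON object of the
// form {"name": ..., "arguments": {...}}. Real parsers (later phases) will
// handle surrounding text, multiple calls, and true incremental streaming.
use crate::tool_parser::types::FunctionCall;
struct SingleJsonParser;
#[async_trait]
impl ToolParser for SingleJsonParser {
    async fn parse_complete(&self, output: &str) -> ToolParserResult<Vec<ToolCall>> {
        // `?` converts serde_json::Error via ToolParserError::JsonError.
        let v: serde_json::Value = serde_json::from_str(output)?;
        Ok(vec![ToolCall {
            id: format!("call-{}", uuid::Uuid::new_v4()),
            r#type: "function".to_string(),
            function: FunctionCall {
                name: v["name"].as_str().unwrap_or_default().to_string(),
                arguments: v["arguments"].to_string(),
            },
        }])
    }
    async fn parse_incremental(
        &self,
        chunk: &str,
        state: &mut ParseState,
    ) -> ToolParserResult<StreamResult> {
        // This sketch only buffers; it never emits partial results.
        state.buffer.push_str(chunk);
        Ok(StreamResult::Incomplete)
    }
    fn detect_format(&self, text: &str) -> bool {
        text.trim_start().starts_with('{')
    }
}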
// tool_parser/types.rs (new file)
use serde::{Deserialize, Serialize};
/// Parsed tool call from model output (OpenAI format)
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCall {
/// Unique identifier for the tool call
pub id: String,
/// Type of tool call (currently always "function")
#[serde(rename = "type")]
pub r#type: String,
/// Function call details
pub function: FunctionCall,
}
/// Function call within a tool call
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct FunctionCall {
/// Name of the function to call
pub name: String,
/// Arguments as JSON string
pub arguments: String,
}
/// Streaming parse result
#[derive(Debug, Clone)]
pub enum StreamResult {
/// Need more data to continue parsing
Incomplete,
/// Found a tool name (for streaming)
ToolName { index: usize, name: String },
/// Found incremental arguments (for streaming)
ToolArguments { index: usize, arguments: String },
/// Completed parsing a tool
ToolComplete(ToolCall),
/// Normal text (not part of tool call)
NormalText(String),
}
/// Token configuration for parsing
#[derive(Debug, Clone)]
pub struct TokenConfig {
/// Start tokens for tool calls
pub start_tokens: Vec<String>,
/// End tokens for tool calls
pub end_tokens: Vec<String>,
/// Separator between multiple tool calls
pub separator: String,
}
impl TokenConfig {
/// Iterate over start/end token pairs
pub fn iter_pairs(&self) -> impl Iterator<Item = (&str, &str)> {
self.start_tokens
.iter()
.zip(self.end_tokens.iter())
.map(|(s, e)| (s.as_str(), e.as_str()))
}
}
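// Illustrative note, not part of this commit: iter_pairs zips the two vectors,
// so a start token without a matching end token is silently dropped. Callers
// should keep start_tokens and end_tokens the same length.
fn demo_iter_pairs() {
    let config = TokenConfig {
        start_tokens: vec!["<tool>".to_string(), "<fn>".to_string()],
        end_tokens: vec!["</tool>".to_string()],
        separator: ",".to_string(),
    };
    assert_eq!(config.iter_pairs().count(), 1); // zip truncates to the shorter side
}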
/// Simple partial tool call for streaming
#[derive(Debug, Clone)]
pub struct PartialToolCall {
/// Tool name (if parsed)
pub name: Option<String>,
/// Buffer for accumulating arguments
pub arguments_buffer: String,
/// Start position in the input buffer
pub start_position: usize,
/// Whether the name has been sent (for streaming)
pub name_sent: bool,
/// Arguments already streamed
pub streamed_args: String,
}