json_parser.rs 13.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
use async_trait::async_trait;
use regex::Regex;
use serde_json::Value;

use crate::tool_parser::{
    errors::{ToolParserError, ToolParserResult},
    partial_json::PartialJson,
    state::ParseState,
    traits::ToolParser,
    types::{FunctionCall, StreamResult, ToolCall},
};

/// JSON format parser for tool calls
///
/// Handles various JSON formats for function calling:
/// - Single tool call: {"name": "fn", "arguments": {...}}
/// - Multiple tool calls: [{"name": "fn1", "arguments": {...}}, ...]
/// - With parameters instead of arguments: {"name": "fn", "parameters": {...}}
///
/// The tool name is read from the "name" key, falling back to "function".
///
/// Supports configurable token markers for different models. Start and end
/// tokens are paired by position: the first start token goes with the first
/// end token, and so on.
pub struct JsonParser {
    /// Token(s) that mark the start of tool calls
    start_tokens: Vec<String>,
    /// Token(s) that mark the end of tool calls
    end_tokens: Vec<String>,
    /// Separator between multiple tool calls (reserved for future use;
    /// stored by the constructor but not read by the parser)
    _separator: String,
    /// Parser for handling incomplete JSON during streaming
    partial_json: PartialJson,
    /// Regex patterns for extracting content between tokens; one per
    /// start/end pair in which both tokens are non-empty
    extractors: Vec<Regex>,
}

impl JsonParser {
    /// Build a parser with the default configuration: no wrapper tokens
    /// (input is treated as bare JSON) and `", "` as the separator.
    pub fn new() -> Self {
        let no_start_tokens = Vec::new();
        let no_end_tokens = Vec::new();
        Self::with_config(no_start_tokens, no_end_tokens, String::from(", "))
    }

    /// Build a parser that recognises the given wrapper tokens.
    ///
    /// `start_tokens` and `end_tokens` are paired by position. For every pair
    /// in which both tokens are non-empty, a regex is compiled that captures
    /// whatever sits between the two tokens. Pairs with an empty token are
    /// skipped here. `separator` is stored but not otherwise used.
    pub fn with_config(
        start_tokens: Vec<String>,
        end_tokens: Vec<String>,
        separator: String,
    ) -> Self {
        let mut extractors = Vec::new();

        for (open, close) in start_tokens.iter().zip(&end_tokens) {
            if open.is_empty() || close.is_empty() {
                continue;
            }

            // (?s) turns on DOTALL so the lazy capture can span newlines.
            let pattern = format!(r"(?s){}(.*?){}", regex::escape(open), regex::escape(close));

            // A pattern that fails to compile is dropped rather than reported.
            if let Ok(compiled) = Regex::new(&pattern) {
                extractors.push(compiled);
            }
        }

        Self {
            start_tokens,
            end_tokens,
            _separator: separator,
            partial_json: PartialJson::default(),
            extractors,
        }
    }

    /// Return the slice of `text` that should hold the JSON payload.
    ///
    /// The input is trimmed first. If any configured extractor matches, the
    /// trimmed capture of the first matching extractor is used. After that,
    /// for every token pair with a non-empty start and an empty end token,
    /// the start token is stripped from the front when present.
    fn extract_json_content<'a>(&self, text: &'a str) -> &'a str {
        let trimmed = text.trim();

        // First extractor whose capture group matched wins.
        let unwrapped = self
            .extractors
            .iter()
            .find_map(|extractor| {
                extractor
                    .captures(trimmed)
                    .and_then(|captures| captures.get(1))
            })
            .map_or(trimmed, |inner| inner.as_str().trim());

        // Start token configured without an end token: drop it as a prefix.
        self.start_tokens
            .iter()
            .zip(&self.end_tokens)
            .filter(|(start, end)| !start.is_empty() && end.is_empty())
            .fold(unwrapped, |rest, (start, _)| {
                rest.strip_prefix(start.as_str()).unwrap_or(rest)
            })
    }

    /// Turn one JSON value into a `ToolCall`, if it looks like one.
    ///
    /// Returns `Ok(None)` when the value has no string under "name" (or,
    /// failing that key, under "function"). Arguments come from "arguments"
    /// or "parameters" and default to an empty object. They are re-serialized
    /// to a JSON string. The "id" string is used when present, otherwise a
    /// `call_<uuid>` id is generated.
    ///
    /// # Errors
    /// Returns `ToolParserError::ParsingFailed` if the arguments cannot be
    /// serialized.
    fn parse_single_object(&self, obj: &Value) -> ToolParserResult<Option<ToolCall>> {
        // Without a tool name this is not a tool call.
        let Some(name) = obj
            .get("name")
            .or_else(|| obj.get("function"))
            .and_then(Value::as_str)
        else {
            return Ok(None);
        };

        // Accept either key for the arguments; fall back to `{}`.
        let serialized = match obj.get("arguments").or_else(|| obj.get("parameters")) {
            Some(args) => serde_json::to_string(args),
            None => serde_json::to_string(&Value::Object(serde_json::Map::new())),
        };
        let arguments =
            serialized.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;

        // Keep a caller-supplied id, otherwise mint a fresh one.
        let id = match obj.get("id").and_then(Value::as_str) {
            Some(given) => given.to_owned(),
            None => format!("call_{}", uuid::Uuid::new_v4()),
        };

        Ok(Some(ToolCall {
            id,
            r#type: "function".to_string(),
            function: FunctionCall {
                name: name.to_string(),
                arguments,
            },
        }))
    }

    /// Collect every tool call found in `value`.
    ///
    /// An array is scanned element by element and an object is treated as a
    /// single candidate. Elements that are not tool calls are skipped. Any
    /// other JSON value yields an empty list. The first serialization error
    /// aborts the scan and is returned.
    fn parse_json_value(&self, value: &Value) -> ToolParserResult<Vec<ToolCall>> {
        match value {
            Value::Array(items) => items
                .iter()
                // `transpose` drops non-tool elements while keeping errors.
                .filter_map(|item| self.parse_single_object(item).transpose())
                .collect(),
            Value::Object(_) => Ok(self.parse_single_object(value)?.into_iter().collect()),
            // Scalars cannot describe a tool call.
            _ => Ok(Vec::new()),
        }
    }

    /// Report whether `text` might contain a tool call.
    ///
    /// With start tokens configured, any one of them appearing in `text` is
    /// enough. Without them, the presence of `{` or `[` is taken as the start
    /// of a JSON object or array.
    fn has_tool_markers(&self, text: &str) -> bool {
        if !self.start_tokens.is_empty() {
            return self
                .start_tokens
                .iter()
                .any(|marker| text.contains(marker.as_str()));
        }

        // No wrapper tokens: look for the opening of a JSON structure.
        text.chars().any(|c| matches!(c, '{' | '['))
    }
}

impl Default for JsonParser {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl ToolParser for JsonParser {
    /// Parse a complete response and return every tool call it contains.
    ///
    /// Wrapper tokens are stripped first. Text that is not valid JSON is not
    /// an error: it yields an empty list.
    async fn parse_complete(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> {
        let payload = self.extract_json_content(text);

        // Invalid JSON means "no tool calls here", not a failure.
        let Ok(parsed) = serde_json::from_str::<Value>(payload) else {
            return Ok(Vec::new());
        };

        self.parse_json_value(&parsed)
    }

    async fn parse_incremental(
        &self,
        chunk: &str,
        state: &mut ParseState,
    ) -> ToolParserResult<StreamResult> {
        state.buffer.push_str(chunk);

        // Check if we have potential tool calls
        if !self.has_tool_markers(&state.buffer) {
            // No tool markers, return as incomplete
            return Ok(StreamResult::Incomplete);
        }

        // Extract JSON content
        let json_content = self.extract_json_content(&state.buffer);

        // Try to parse with partial JSON parser
        match self.partial_json.parse_value(json_content) {
            Ok((value, consumed)) => {
                // Check if we have a complete JSON structure
                if consumed == json_content.len() {
                    // Complete JSON, parse tool calls
                    let tools = self.parse_json_value(&value)?;
                    if !tools.is_empty() {
                        // Clear buffer since we consumed everything
                        state.buffer.clear();

                        // Return the first tool as complete (simplified for Phase 2)
                        if let Some(tool) = tools.into_iter().next() {
                            return Ok(StreamResult::ToolComplete(tool));
                        }
                    }
                } else {
                    // Partial JSON, try to extract tool name
                    if let Some(name) = value.get("name").and_then(|v| v.as_str()) {
                        // Simple implementation for Phase 2
                        // Just return the tool name once we see it
                        if !state.in_string {
                            state.in_string = true; // Use as a flag for "name sent"
                            return Ok(StreamResult::ToolName {
                                index: 0,
                                name: name.to_string(),
                            });
                        }

                        // Check for complete arguments
                        if let Some(args) =
                            value.get("arguments").or_else(|| value.get("parameters"))
                        {
                            if let Ok(args_str) = serde_json::to_string(args) {
                                // Return arguments as a single update
                                return Ok(StreamResult::ToolArguments {
                                    index: 0,
                                    arguments: args_str,
                                });
                            }
                        }
                    }
                }
            }
            Err(_) => {
                // Failed to parse even as partial JSON
                // Keep buffering
            }
        }

        Ok(StreamResult::Incomplete)
    }

    /// Report whether `text` looks like a JSON tool call this parser handles.
    ///
    /// The text must contain a tool marker, must parse as JSON once wrapper
    /// tokens are removed, and must be either an object with a "name" or
    /// "function" key or an array holding at least one such object.
    fn detect_format(&self, text: &str) -> bool {
        if !self.has_tool_markers(text) {
            return false;
        }

        let candidate = self.extract_json_content(text);
        let Ok(parsed) = serde_json::from_str::<Value>(candidate) else {
            return false;
        };

        // Only the presence of the key is checked, not the type of its value.
        let names_a_tool = |fields: &serde_json::Map<String, Value>| {
            fields.contains_key("name") || fields.contains_key("function")
        };

        match &parsed {
            Value::Object(fields) => names_a_tool(fields),
            Value::Array(items) => items
                .iter()
                .filter_map(Value::as_object)
                .any(names_a_tool),
            _ => false,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A bare object with "name" and "arguments" yields exactly one call.
    #[tokio::test]
    async fn test_parse_single_tool_call() {
        let payload = r#"{"name": "get_weather", "arguments": {"location": "San Francisco"}}"#;

        let calls = JsonParser::new().parse_complete(payload).await.unwrap();

        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].function.name, "get_weather");
    }

    /// An array of objects yields one call per element, in order.
    #[tokio::test]
    async fn test_parse_multiple_tool_calls() {
        let payload = r#"[
            {"name": "get_weather", "arguments": {"location": "SF"}},
            {"name": "search", "arguments": {"query": "news"}}
        ]"#;

        let calls = JsonParser::new().parse_complete(payload).await.unwrap();

        assert_eq!(calls.len(), 2);
        let names: Vec<&str> = calls
            .iter()
            .map(|call| call.function.name.as_str())
            .collect();
        assert_eq!(names[0], "get_weather");
        assert_eq!(names[1], "search");
    }

    /// "parameters" is accepted as an alias for "arguments".
    #[tokio::test]
    async fn test_parse_with_parameters_key() {
        let payload = r#"{"name": "calculate", "parameters": {"x": 10, "y": 20}}"#;

        let calls = JsonParser::new().parse_complete(payload).await.unwrap();

        assert_eq!(calls.len(), 1);
        let call = &calls[0];
        assert_eq!(call.function.name, "calculate");
        assert!(call.function.arguments.contains("10"));
    }

    /// Configured wrapper tokens are stripped before the JSON is parsed.
    #[tokio::test]
    async fn test_parse_with_wrapper_tokens() {
        let open = vec!["<tool>".to_string()];
        let close = vec!["</tool>".to_string()];
        let wrapped_parser = JsonParser::with_config(open, close, ", ".to_string());

        let payload = r#"<tool>{"name": "test", "arguments": {}}</tool>"#;
        let calls = wrapped_parser.parse_complete(payload).await.unwrap();

        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].function.name, "test");
    }

    /// Format detection accepts tool-shaped JSON and rejects everything else.
    #[test]
    fn test_detect_format() {
        let parser = JsonParser::new();

        let cases = [
            (r#"{"name": "test", "arguments": {}}"#, true),
            (r#"[{"name": "test"}]"#, true),
            ("plain text", false),
            (r#"{"key": "value"}"#, false),
        ];

        for (input, expected) in cases {
            assert_eq!(parser.detect_format(input), expected);
        }
    }

    /// Phase 2 simplified streaming test: a complete JSON document delivered
    /// in a single chunk must come back as a finished tool call.
    #[tokio::test]
    async fn test_streaming_parse() {
        let parser = JsonParser::new();
        let mut stream_state = ParseState::new();

        let whole_document = r#"{"name": "get_weather", "arguments": {"location": "SF"}}"#;

        let outcome = parser
            .parse_incremental(whole_document, &mut stream_state)
            .await
            .unwrap();

        let StreamResult::ToolComplete(tool) = outcome else {
            panic!("Expected ToolComplete for complete JSON input");
        };
        assert_eq!(tool.function.name, "get_weather");
        assert!(tool.function.arguments.contains("SF"));
    }
}