json.rs 10.4 KB
Newer Older
1
2
3
use async_trait::async_trait;
use serde_json::Value;

4
5
6
7
8
9
10
11
12
use crate::{
    protocols::common::Tool,
    tool_parser::{
        errors::{ParserError, ParserResult},
        parsers::helpers,
        partial_json::PartialJson,
        traits::ToolParser,
        types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
    },
13
14
15
16
};

/// JSON format parser for tool calls
///
17
/// Handles pure JSON formats for function calling:
18
19
20
21
22
23
/// - Single tool call: {"name": "fn", "arguments": {...}}
/// - Multiple tool calls: [{"name": "fn1", "arguments": {...}}, ...]
/// - With parameters instead of arguments: {"name": "fn", "parameters": {...}}
pub struct JsonParser {
    /// Parser for handling incomplete JSON during streaming
    partial_json: PartialJson,
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41

    /// Buffer for accumulating incomplete patterns across chunks
    buffer: String,

    /// Stores complete tool call info (name and arguments) for each tool being parsed
    prev_tool_call_arr: Vec<Value>,

    /// Index of currently streaming tool call (-1 means no active tool)
    current_tool_id: i32,

    /// Flag for whether current tool's name has been sent to client
    current_tool_name_sent: bool,

    /// Tracks raw JSON string content streamed to client for each tool's arguments
    streamed_args_for_tool: Vec<String>,

    /// Separator between multiple tool calls
    tool_call_separator: &'static str,
42
43
44
45
46
47

    /// Track whether we're parsing array format `[...]` vs single object `{...}`
    is_array_format: bool,

    /// Track whether we've already stripped the closing ] bracket (for array format)
    array_closed: bool,
48
49
50
}

impl JsonParser {
51
    /// Create a new JSON parser
52
53
54
    pub fn new() -> Self {
        Self {
            partial_json: PartialJson::default(),
55
56
57
58
59
60
            buffer: String::new(),
            prev_tool_call_arr: Vec::new(),
            current_tool_id: -1,
            current_tool_name_sent: false,
            streamed_args_for_tool: Vec::new(),
            tool_call_separator: ",",
61
62
            is_array_format: false,
            array_closed: false,
63
64
65
        }
    }

66
67
68
69
70
71
72
73
74
75
76
77
78
    /// Try to extract a first valid JSON object or array from text that may contain other content
    /// Returns (json_string, normal_text) where normal_text is text before and after the JSON
    fn extract_json_from_text(&self, text: &str) -> Option<(String, String)> {
        let mut in_string = false;
        let mut escape = false;
        let mut stack: Vec<char> = Vec::with_capacity(8);
        let mut start: Option<usize> = None;

        for (i, ch) in text.char_indices() {
            if escape {
                escape = false;
                continue;
            }
79

80
81
82
83
84
85
86
            match ch {
                '\\' if in_string => escape = true,
                '"' => in_string = !in_string,
                _ if in_string => {}
                '{' | '[' => {
                    if start.is_none() {
                        start = Some(i);
87
                    }
88
                    stack.push(ch);
89
                }
90
91
92
93
94
95
                '}' | ']' => {
                    let Some(open) = stack.pop() else {
                        // Stray closer - reset and continue looking for next valid JSON
                        start = None;
                        continue;
                    };
96

97
98
99
100
101
102
103
                    let valid = (open == '{' && ch == '}') || (open == '[' && ch == ']');
                    if !valid {
                        // Mismatch - reset and continue looking
                        start = None;
                        stack.clear();
                        continue;
                    }
104

105
106
107
108
109
110
111
112
113
114
115
116
117
118
                    if stack.is_empty() {
                        let s = start.unwrap();
                        let e = i + ch.len_utf8();
                        let potential_json = &text[s..e];

                        // Validate that this is actually valid JSON before returning
                        if serde_json::from_str::<Value>(potential_json).is_ok() {
                            let json = potential_json.to_string();
                            let normal = format!("{}{}", &text[..s], &text[e..]);
                            return Some((json, normal));
                        } else {
                            // Not valid JSON, reset and continue looking
                            start = None;
                            continue;
119
120
121
                        }
                    }
                }
122
                _ => {}
123
124
125
            }
        }
        None
126
127
128
    }

    /// Parse a single JSON object into a ToolCall
129
    fn parse_single_object(&self, obj: &Value) -> ParserResult<Option<ToolCall>> {
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
        // Check if this looks like a tool call
        let name = obj
            .get("name")
            .or_else(|| obj.get("function"))
            .and_then(|v| v.as_str());

        if let Some(name) = name {
            // Get arguments - support both "arguments" and "parameters" keys
            let empty_obj = Value::Object(serde_json::Map::new());
            let args = obj
                .get("arguments")
                .or_else(|| obj.get("parameters"))
                .unwrap_or(&empty_obj);

            // Convert arguments to JSON string
            let arguments = serde_json::to_string(args)
146
                .map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
147
148
149
150
151
152
153
154
155
156
157
158
159

            Ok(Some(ToolCall {
                function: FunctionCall {
                    name: name.to_string(),
                    arguments,
                },
            }))
        } else {
            Ok(None)
        }
    }

    /// Parse JSON value(s) into tool calls
160
    fn parse_json_value(&self, value: &Value) -> ParserResult<Vec<ToolCall>> {
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
        let mut tools = Vec::new();

        match value {
            Value::Array(arr) => {
                // Parse each element in the array
                for item in arr {
                    if let Some(tool) = self.parse_single_object(item)? {
                        tools.push(tool);
                    }
                }
            }
            Value::Object(_) => {
                // Single tool call
                if let Some(tool) = self.parse_single_object(value)? {
                    tools.push(tool);
                }
            }
            _ => {
                // Not a valid tool call format
                return Ok(vec![]);
            }
        }

        Ok(tools)
    }
}

impl Default for JsonParser {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl ToolParser for JsonParser {
196
    async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
197
198
199
        // Always use extract_json_from_text to handle both pure JSON and mixed content
        if let Some((extracted_json, normal_text)) = self.extract_json_from_text(text) {
            let parsed = serde_json::from_str::<Value>(&extracted_json)
200
                .map_err(|e| ParserError::ParsingFailed(e.to_string()))
201
202
203
204
205
                .and_then(|v| self.parse_json_value(&v));

            match parsed {
                Ok(tools) => return Ok((normal_text, tools)),
                Err(e) => tracing::warn!("parse_complete failed: {:?}", e),
206
207
208
            }
        }

209
210
        // No valid JSON found, return original text as normal text
        Ok((text.to_string(), vec![]))
211
212
213
    }

    async fn parse_incremental(
214
        &mut self,
215
        chunk: &str,
216
        tools: &[Tool],
217
    ) -> ParserResult<StreamingParseResult> {
218
219
220
221
        // Append new text to buffer
        self.buffer.push_str(chunk);
        let current_text = &self.buffer.clone();

222
223
224
225
226
        // Determine format on first parse (array vs single object)
        if self.current_tool_id == -1 && self.has_tool_markers(current_text) {
            self.is_array_format = current_text.trim().starts_with('[');
        }

227
        // Check if current_text has tool_call
228
229
230
        // Once array is closed, don't treat [ or { as tool markers
        let has_tool_start = (!self.array_closed && self.has_tool_markers(current_text))
            || (self.current_tool_id > 0 && current_text.starts_with(self.tool_call_separator));
231
232

        if !has_tool_start {
233
            let mut normal_text = self.buffer.clone();
234
235
            self.buffer.clear();

236
237
238
239
240
241
242
243
244
245
246
            // Strip ] only once (the closing bracket of JSON array format)
            // Only for array format and only if we haven't already closed it
            if self.is_array_format
                && !self.array_closed
                && self.current_tool_id > 0
                && normal_text.starts_with("]")
            {
                normal_text = normal_text.strip_prefix("]").unwrap().to_string();
                self.array_closed = true;
            }

247
248
249
250
            return Ok(StreamingParseResult {
                normal_text,
                calls: vec![],
            });
251
252
        }

253
254
        // Build tool indices
        let tool_indices = helpers::get_tool_indices(tools);
255

256
257
258
259
260
        // Determine start index for JSON parsing
        // JSON can start with [ (array) or { (single object)
        let start_idx = if let Some(bracket_pos) = current_text.find('[') {
            let brace_pos = current_text.find('{');
            match brace_pos {
261
                Some(bp) => bp,
262
                _ => bracket_pos,
263
            }
264
265
        } else if let Some(brace_pos) = current_text.find('{') {
            brace_pos
266
        } else if self.current_tool_id > 0 && current_text.starts_with(self.tool_call_separator) {
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
            self.tool_call_separator.len()
        } else {
            0
        };

        helpers::handle_json_tool_streaming(
            current_text,
            start_idx,
            &mut self.partial_json,
            &tool_indices,
            &mut self.buffer,
            &mut self.current_tool_id,
            &mut self.current_tool_name_sent,
            &mut self.streamed_args_for_tool,
            &mut self.prev_tool_call_arr,
        )
283
284
    }

285
    fn has_tool_markers(&self, text: &str) -> bool {
286
        let trimmed = text.trim();
287
        trimmed.starts_with('[') || trimmed.starts_with('{')
288
    }
289
290
291
292

    fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
        helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
    }
293
294
295
296
297
298
299
300
301

    fn reset(&mut self) {
        helpers::reset_parser_state(
            &mut self.buffer,
            &mut self.prev_tool_call_arr,
            &mut self.current_tool_id,
            &mut self.current_tool_name_sent,
            &mut self.streamed_args_for_tool,
        );
302
303
        self.is_array_format = false;
        self.array_closed = false;
304
    }
305
}