helpers.rs 15.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
use crate::protocols::spec::Tool;
use serde_json::Value;
use std::collections::HashMap;

use crate::tool_parser::errors::{ToolParserError, ToolParserResult};
use crate::tool_parser::types::{StreamingParseResult, ToolCallItem};

/// Build a lookup table from each tool's function name to its position in `tools`.
///
/// Positions are the zero-based indices of the slice. If several tools share
/// the same function name, the entry for the last one wins.
pub fn get_tool_indices(tools: &[Tool]) -> HashMap<String, usize> {
    let mut index_by_name = HashMap::with_capacity(tools.len());
    for (position, tool) in tools.iter().enumerate() {
        index_by_name.insert(tool.function.name.clone(), position);
    }
    index_by_name
}

17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
/// Compute the argument text of the last tracked tool call that has been
/// parsed but not yet streamed.
///
/// The last element of `prev_tool_call_arr` is treated as the tool call in
/// progress. Its `"arguments"` value is serialized to JSON text and compared
/// with the text recorded at the same index of `streamed_args_for_tool`.
///
/// # Returns
/// - `Some(vec![item])` with a single name-less [`ToolCallItem`] carrying the
///   unsent suffix, when the streamed text is a strict prefix of the
///   serialized arguments.
/// - `None` when nothing is tracked, the index has no streamed entry, the tool
///   call has no `"arguments"` field, serialization fails, the streamed text
///   is not a prefix of the serialized arguments, or nothing is left to send.
pub fn get_unstreamed_args(
    prev_tool_call_arr: &[Value],
    streamed_args_for_tool: &[String],
) -> Option<Vec<ToolCallItem>> {
    // Index of the most recent tool call; bails out when no call is tracked.
    let tool_index = prev_tool_call_arr.len().checked_sub(1)?;

    // Bails out when no streamed-args slot exists for that tool (this also
    // covers an empty `streamed_args_for_tool`).
    let already_sent = streamed_args_for_tool.get(tool_index)?;

    // Full argument text as it would be emitted once the call is complete.
    let full_args = prev_tool_call_arr[tool_index]
        .get("arguments")
        .and_then(|args| serde_json::to_string(args).ok())?;

    // Only a clean prefix relationship lets us emit the rest as a delta.
    let unsent = full_args.strip_prefix(already_sent.as_str())?;
    if unsent.is_empty() {
        return None;
    }

    let delta = ToolCallItem {
        tool_index,
        name: None, // argument deltas never carry a name
        parameters: unsent.to_string(),
    };
    Some(vec![delta])
}

59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
/// Check whether `buffer` ends with a partial occurrence of `token`.
///
/// A partial occurrence is a non-empty prefix of `token` that is strictly
/// shorter than the whole token. Only prefixes ending on a UTF-8 character
/// boundary are considered, so tokens containing multi-byte characters are
/// handled without panicking.
///
/// # Returns
/// - `Some(len)` where `len` is the byte length of the shortest prefix of
///   `token` that `buffer` ends with.
/// - `None` if `buffer` or `token` is empty, or no such prefix exists (a
///   buffer ending with the complete token does not count by itself).
pub fn ends_with_partial_token(buffer: &str, token: &str) -> Option<usize> {
    if buffer.is_empty() || token.is_empty() {
        return None;
    }

    // `char_indices` yields the starting byte offset of every character.
    // Skipping the first (offset 0) leaves exactly the char-boundary offsets
    // in `1..token.len()`, so `&token[..end]` can never split a character.
    token
        .char_indices()
        .skip(1)
        .map(|(end, _)| end)
        .find(|&end| buffer.ends_with(&token[..end]))
}

/// Discard the state of the tool call currently being parsed (used when an
/// invalid tool is skipped).
///
/// Clears `buffer` and resets `current_tool_name_sent` to `false`. The last
/// entry of `streamed_args_for_tool` is removed only when that vector is
/// longer than `prev_tool_call_arr`, i.e. when it holds an extra entry beyond
/// the tracked tool calls. `prev_tool_call_arr` itself is only read, and the
/// current tool id is not touched by this function.
pub fn reset_current_tool_state(
    buffer: &mut String,
    current_tool_name_sent: &mut bool,
    streamed_args_for_tool: &mut Vec<String>,
    prev_tool_call_arr: &[Value],
) {
    // An extra trailing entry belongs to the tool being abandoned; drop it so
    // both collections line up again.
    let has_extra_entry = streamed_args_for_tool.len() > prev_tool_call_arr.len();
    if has_extra_entry {
        streamed_args_for_tool.pop();
    }

    *current_tool_name_sent = false;
    buffer.clear();
}

/// Wipe every piece of parser state (intended for the start of a new request).
///
/// After the call `buffer`, `prev_tool_call_arr` and `streamed_args_for_tool`
/// are empty, `current_tool_id` is `0` and `current_tool_name_sent` is `false`.
pub fn reset_parser_state(
    buffer: &mut String,
    prev_tool_call_arr: &mut Vec<Value>,
    current_tool_id: &mut i32,
    current_tool_name_sent: &mut bool,
    streamed_args_for_tool: &mut Vec<String>,
) {
    // Scalar flags first.
    *current_tool_name_sent = false;
    *current_tool_id = 0;

    // Then drop all accumulated per-tool data and any buffered text.
    streamed_args_for_tool.clear();
    prev_tool_call_arr.clear();
    buffer.clear();
}

/// Grow both tracking vectors so that index `current_tool_id` is valid.
///
/// Missing slots are filled with `Value::Null` in `prev_tool_call_arr` and
/// with empty strings in `streamed_args_for_tool`. Vectors that are already
/// long enough are left untouched (they are never shrunk). A negative
/// `current_tool_id` makes the call a no-op.
pub fn ensure_capacity(
    current_tool_id: i32,
    prev_tool_call_arr: &mut Vec<Value>,
    streamed_args_for_tool: &mut Vec<String>,
) {
    if current_tool_id >= 0 {
        let required_len = (current_tool_id + 1) as usize;

        // Guard each resize: `resize` would otherwise truncate longer vectors.
        if required_len > prev_tool_call_arr.len() {
            prev_tool_call_arr.resize(required_len, Value::Null);
        }
        if required_len > streamed_args_for_tool.len() {
            streamed_args_for_tool.resize(required_len, String::new());
        }
    }
}

/// Report whether `input` deserializes successfully as a JSON [`Value`].
///
/// Returns `true` when `serde_json::from_str` accepts the text and `false`
/// on any parse error (for example a truncated object or array).
pub fn is_complete_json(input: &str) -> bool {
    let parsed: Result<Value, _> = serde_json::from_str(input);
    parsed.is_ok()
}

/// Make sure a tool call object exposes its payload under `"arguments"`.
///
/// When `obj` is a JSON object that has a `"parameters"` key but no
/// `"arguments"` key, the `"parameters"` value is copied to `"arguments"`
/// (the original `"parameters"` entry is kept). In every other case,
/// including non-object values, `obj` is returned unchanged.
///
/// # Background
/// Different LLM formats use different field names:
/// - Llama and JSON parsers use "parameters"
/// - Mistral and Qwen use "arguments"
///
/// Normalizing to "arguments" gives downstream code a single field to read.
pub fn normalize_arguments_field(mut obj: Value) -> Value {
    if let Value::Object(fields) = &mut obj {
        if !fields.contains_key("arguments") {
            let copied = fields.get("parameters").cloned();
            if let Some(params) = copied {
                fields.insert("arguments".to_string(), params);
            }
        }
    }
    obj
}

/// Handle the entire JSON tool call streaming process for JSON-based parsers.
///
/// This unified function handles all aspects of streaming tool calls:
/// - Parsing partial JSON from the buffer
/// - Validating tool names against available tools
/// - Streaming tool names (Case 1)
/// - Streaming tool arguments (Case 2)
/// - Managing parser state and buffer updates
///
/// Used by JSON, Llama, Mistral, and Qwen parsers.
///
/// # Parameters
/// - `current_text`: The current buffered text being parsed
/// - `start_idx`: Start index of JSON content in current_text
/// - `partial_json`: Mutable reference to partial JSON parser
/// - `tool_indices`: Map of valid tool names to their indices
/// - `buffer`: Mutable parser buffer
/// - `current_tool_id`: Mutable current tool index (-1 means no active tool)
/// - `current_tool_name_sent`: Mutable flag for whether current tool's name was sent
/// - `streamed_args_for_tool`: Mutable accumulator of streamed arguments per tool
/// - `prev_tool_call_arr`: Mutable array of previous tool call states
///
/// # Returns
/// - `Ok(StreamingParseResult)` with any tool call items to stream. The result
///   is empty when there is nothing to parse at `start_idx`, when the partial
///   JSON parser rejects the text, or when the parsed object names a tool that
///   is not in `tool_indices` (in which case the current tool state is reset).
///
/// # Errors
/// - `ToolParserError::ParsingFailed` if the current arguments cannot be
///   serialized back to JSON text.
#[allow(clippy::too_many_arguments)]
pub fn handle_json_tool_streaming(
    current_text: &str,
    start_idx: usize,
    partial_json: &mut crate::tool_parser::partial_json::PartialJson,
    tool_indices: &HashMap<String, usize>,
    buffer: &mut String,
    current_tool_id: &mut i32,
    current_tool_name_sent: &mut bool,
    streamed_args_for_tool: &mut Vec<String>,
    prev_tool_call_arr: &mut Vec<Value>,
) -> ToolParserResult<StreamingParseResult> {
    // Check if we have content to parse
    if start_idx >= current_text.len() {
        return Ok(StreamingParseResult::default());
    }

    // Extract JSON string from current position
    let json_str = &current_text[start_idx..];

    // Parse partial JSON; a parse failure is not an error here, it just means
    // there is nothing to stream yet.
    let (obj, end_idx) = match partial_json.parse_value(json_str) {
        Ok(result) => result,
        Err(_) => {
            return Ok(StreamingParseResult::default());
        }
    };

    // Check if JSON is complete: the partial parser consumed everything and a
    // strict parse of the same text succeeds.
    let is_complete = end_idx == json_str.len() && serde_json::from_str::<Value>(json_str).is_ok();

    // Validate tool name if present
    if let Some(name) = obj.get("name").and_then(|v| v.as_str()) {
        if !tool_indices.contains_key(name) {
            // Invalid tool name - skip this tool, preserve indexing for next tool
            tracing::warn!("Invalid tool name '{}' - skipping", name);
            reset_current_tool_state(
                buffer,
                current_tool_name_sent,
                streamed_args_for_tool,
                prev_tool_call_arr,
            );
            return Ok(StreamingParseResult::default());
        }
    }

    // Normalize parameters/arguments field
    let current_tool_call = normalize_arguments_field(obj);

    let mut result = StreamingParseResult::default();

    // Case 1: Handle tool name streaming
    if !*current_tool_name_sent {
        if let Some(function_name) = current_tool_call.get("name").and_then(|v| v.as_str()) {
            if tool_indices.contains_key(function_name) {
                // Initialize if first tool
                if *current_tool_id == -1 {
                    *current_tool_id = 0;
                    streamed_args_for_tool.push(String::new());
                } else if *current_tool_id as usize >= streamed_args_for_tool.len() {
                    // Ensure capacity for subsequent tools
                    ensure_capacity(*current_tool_id, prev_tool_call_arr, streamed_args_for_tool);
                }

                // Send tool name with empty parameters
                *current_tool_name_sent = true;
                result.calls.push(ToolCallItem {
                    tool_index: *current_tool_id as usize,
                    name: Some(function_name.to_string()),
                    parameters: String::new(),
                });
            }
        }
    }
    // Case 2: Handle streaming arguments
    else if let Some(cur_arguments) = current_tool_call.get("arguments") {
        let tool_id = *current_tool_id as usize;
        let sent = streamed_args_for_tool
            .get(tool_id)
            .map(|s| s.len())
            .unwrap_or(0);
        let cur_args_json = serde_json::to_string(cur_arguments)
            .map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;

        // Compute diff: everything after what we've already sent.
        //
        // `sent` is a byte length taken from text serialized on an earlier
        // call, so in the new serialization it may exceed the string length or
        // fall inside a multi-byte character. Slicing at such an offset would
        // panic, so clamp it to the length and advance to the next character
        // boundary (`len()` is always a boundary, so the loop terminates).
        let mut diff_start = sent.min(cur_args_json.len());
        while !cur_args_json.is_char_boundary(diff_start) {
            diff_start += 1;
        }
        let diff = cur_args_json[diff_start..].to_string();

        // Send diff if there's new content
        if !diff.is_empty() {
            // Only accumulate if not complete
            if !is_complete && tool_id < streamed_args_for_tool.len() {
                streamed_args_for_tool[tool_id].push_str(&diff);
            }

            result.calls.push(ToolCallItem {
                tool_index: tool_id,
                name: None,
                parameters: diff,
            });
        }

        // If JSON is complete, advance to next tool
        if is_complete {
            // Remove processed portion, keep unprocessed content
            *buffer = current_text[start_idx + end_idx..].to_string();

            // Clear completed tool data
            if tool_id < prev_tool_call_arr.len() {
                prev_tool_call_arr[tool_id] = Value::Null;
            }
            *current_tool_name_sent = false;
            if tool_id < streamed_args_for_tool.len() {
                streamed_args_for_tool[tool_id].clear();
            }
            *current_tool_id += 1;
        }
    }

    // Update prev_tool_call_arr with current state. Note that after a
    // completed call `current_tool_id` has already been advanced, so the
    // parsed object is stored in the slot of the next tool.
    if *current_tool_id >= 0 {
        ensure_capacity(*current_tool_id, prev_tool_call_arr, streamed_args_for_tool);
        let tool_id = *current_tool_id as usize;

        if tool_id < prev_tool_call_arr.len() {
            prev_tool_call_arr[tool_id] = current_tool_call;
        }
    }

    Ok(result)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Partial-token detection on an ASCII marker token.
    #[test]
    fn test_ends_with_partial_token() {
        let token = "<|python_tag|>";

        // Buffers ending in a strict prefix of the token are detected.
        for buffer in ["hello <|py", "hello <|python_tag"] {
            assert!(
                ends_with_partial_token(buffer, token).is_some(),
                "expected a partial match for {buffer:?}"
            );
        }

        // A complete token, an empty buffer and unrelated text are not.
        for buffer in ["hello <|python_tag|>", "", "hello world"] {
            assert!(
                ends_with_partial_token(buffer, token).is_none(),
                "expected no partial match for {buffer:?}"
            );
        }
    }

    /// An extra streamed-args entry beyond the tracked calls is popped.
    #[test]
    fn test_reset_current_tool_state() {
        let tracked_calls = vec![serde_json::json!({"name": "tool0"})];
        let mut streamed = vec![String::from("tool0_args"), String::from("tool1_partial")];
        let mut text = "partial json".to_string();
        let mut name_sent = true;

        reset_current_tool_state(&mut text, &mut name_sent, &mut streamed, &tracked_calls);

        assert!(text.is_empty());
        assert!(!name_sent);
        // Only the partial tool1 entry was removed.
        assert_eq!(streamed, vec![String::from("tool0_args")]);
    }

    /// Nothing is popped when both collections already have the same length.
    #[test]
    fn test_reset_current_tool_state_no_pop_when_synced() {
        let tracked_calls = vec![serde_json::json!({"name": "tool0"})];
        let mut streamed = vec![String::from("tool0_args")];
        let mut text = "partial json".to_string();
        let mut name_sent = true;

        reset_current_tool_state(&mut text, &mut name_sent, &mut streamed, &tracked_calls);

        assert!(text.is_empty());
        assert!(!name_sent);
        assert_eq!(streamed.len(), 1);
    }

    /// A full reset empties every collection and zeroes the tool id.
    #[test]
    fn test_reset_parser_state() {
        let mut text = "some buffer".to_string();
        let mut tracked_calls = vec![serde_json::json!({"name": "tool0"})];
        let mut tool_id = 5;
        let mut name_sent = true;
        let mut streamed = vec![String::from("args")];

        reset_parser_state(
            &mut text,
            &mut tracked_calls,
            &mut tool_id,
            &mut name_sent,
            &mut streamed,
        );

        assert!(text.is_empty());
        assert!(tracked_calls.is_empty());
        assert_eq!(tool_id, 0);
        assert!(!name_sent);
        assert!(streamed.is_empty());
    }

    /// Both vectors grow to `id + 1` entries filled with defaults.
    #[test]
    fn test_ensure_capacity() {
        let mut tracked_calls: Vec<Value> = Vec::new();
        let mut streamed: Vec<String> = Vec::new();

        ensure_capacity(2, &mut tracked_calls, &mut streamed);

        assert_eq!(tracked_calls.len(), 3);
        assert_eq!(streamed.len(), 3);
        assert_eq!(tracked_calls[0], Value::Null);
        assert_eq!(streamed[0], "");
    }

    /// A negative id must leave both vectors untouched.
    #[test]
    fn test_ensure_capacity_negative_id() {
        let mut tracked_calls: Vec<Value> = Vec::new();
        let mut streamed: Vec<String> = Vec::new();

        ensure_capacity(-1, &mut tracked_calls, &mut streamed);

        assert!(tracked_calls.is_empty());
        assert!(streamed.is_empty());
    }

    /// Complete JSON documents are accepted, truncated ones rejected.
    #[test]
    fn test_is_complete_json() {
        for complete in [r#"{"name": "test"}"#, "[1, 2, 3]", "42", "true"] {
            assert!(is_complete_json(complete), "should be complete: {complete:?}");
        }
        for truncated in [r#"{"name": "#, "[1, 2,"] {
            assert!(!is_complete_json(truncated), "should be incomplete: {truncated:?}");
        }
    }

    /// `parameters` is copied to `arguments` only when `arguments` is absent.
    #[test]
    fn test_normalize_arguments_field() {
        // Case 1: Has parameters, no arguments
        let with_parameters = serde_json::json!({
            "name": "test",
            "parameters": {"key": "value"}
        });
        let normalized = normalize_arguments_field(with_parameters);
        assert_eq!(
            normalized.get("arguments").unwrap(),
            &serde_json::json!({"key": "value"})
        );

        // Case 2: Already has arguments
        let with_arguments = serde_json::json!({
            "name": "test",
            "arguments": {"key": "value"}
        });
        assert_eq!(
            normalize_arguments_field(with_arguments.clone()),
            with_arguments
        );

        // Case 3: No parameters or arguments
        let name_only = serde_json::json!({"name": "test"});
        assert_eq!(normalize_arguments_field(name_only.clone()), name_only);
    }
}