"lib/bindings/vscode:/vscode.git/clone" did not exist on "30610e73716aad10c6189691365bf8194cc92b2d"
json_parser.rs 9.38 KB
Newer Older
1
2
3
4
5
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::collections::HashMap;

6
use regex::RegexBuilder;
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
use serde_json::Value;
use uuid::Uuid;

use super::parsers::JsonParserConfig;
use super::response::{CalledFunction, ToolCallResponse, ToolCallType};

/// A tool call whose inputs are given under a `parameters` key, e.g.
/// `{ "name": "search", "parameters": { "query": "rust" } }`.
///
/// Mirrors [`CalledFunction`], except the inputs are kept as a decoded JSON map
/// instead of a serialized string.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub struct CalledFunctionParameters {
    /// Name of the function the model wants to call.
    pub name: String,
    /// Named inputs, keyed by parameter name.
    pub parameters: HashMap<String, Value>,
}

/// A tool call whose inputs are given under an `arguments` key, e.g.
/// `{ "name": "get_weather", "arguments": { "location": "San Francisco" } }`.
///
/// Same shape as [`CalledFunctionParameters`]; only the JSON key differs.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct CalledFunctionArguments {
    /// Name of the function the model wants to call.
    pub name: String,
    /// Named inputs, keyed by argument name.
    pub arguments: HashMap<String, Value>,
}

27
28
29
// Extract the contents between start and end tokens using regex parsing.
// Returns a JSON array string if there are multiple matches, otherwise returns the last match directly.
fn extract_tool_call_content(input: &str, start_token: &str, end_token: &str) -> Option<String> {
30
31
32
33
34
35
36
37
38
    let escaped_start = regex::escape(start_token);
    let escaped_end = regex::escape(end_token);
    let pattern = format!(r"{}(.*?){}", escaped_start, escaped_end);

    match RegexBuilder::new(&pattern)
        .dot_matches_new_line(true)
        .build()
    {
        Ok(regex) => {
39
            // Get all matches and take the last one for now. TODO: Handle multiple tool calls
40
41
42
            let matches: Vec<_> = regex
                .captures_iter(input)
                .filter_map(|captures| captures.get(1))
43
                .map(|m| m.as_str().trim().to_string())
44
                .collect();
45
46
47
48
49
50
51
52
53
54
55
            if !matches.is_empty() {
                // If only one match, return it directly, otherwise return as a JSON array string
                if matches.len() == 1 {
                    // Return the last match directly
                    return Some(matches.last().unwrap().clone());
                } else {
                    // Join the matches into a JSON array string
                    return Some(format!("[{}]", matches.join(",")));
                }
            }
            None
56
57
58
59
60
        }
        Err(_) => None,
    }
}

61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
// Special case for <|python_tag|> . Regex pattern does not work well with it as it has no end token
// Handles single tool and multiple tool call cases for single start_token like <|python_tag|>
fn handle_single_token_tool_calls(input: &str, start_token: &str) -> String {
    // Return the input if it doesn't contain the start token
    if !input.contains(start_token) {
        return input.to_string();
    }

    // Split on the start token and keep only JSON-looking segments
    let mut items: Vec<String> = Vec::new();
    for seg in input.split(start_token) {
        let s = seg.trim();
        if s.is_empty() {
            continue;
        }
        // Only consider segments that start like JSON
77
        if s.starts_with('{') {
78
            // Trim trailing non-JSON by cutting at the last closing brace/bracket
79
80
            if let Some(pos) = s.rfind('}') {
                let candidate = &s[..=pos].trim();
81
82
83
84
85
86
87
88
                // Keep only valid JSON candidates
                if serde_json::from_str::<serde_json::Value>(candidate).is_ok() {
                    items.push(candidate.to_string());
                }
            }
        }
    }
    if items.is_empty() {
89
90
91
92
93
94
95
96
        // Remove everything up to and including the first occurrence of the start token
        if let Some(idx) = input.find(start_token) {
            let rest = &input[idx + start_token.len()..];
            return rest.trim_start().to_string();
        } else {
            // Shouldn't happen because we checked contains() above, but be defensive
            return input.to_string();
        }
97
98
99
100
    }
    format!("[{}]", items.join(","))
}

101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
/// Attempts to parse a tool call from a raw LLM message string into a unified [`ToolCallResponse`] format.
///
/// This is a flexible helper that handles a variety of potential formats emitted by LLMs for function/tool calls,
/// including wrapped payloads (`<TOOLCALL>[...]</TOOLCALL>`, `<|python_tag|>...`) and JSON representations
/// with either `parameters` or `arguments` fields.
///
/// # Supported Formats
///
/// The input `message` may be one of:
///
/// - `<TOOLCALL>[{ "name": ..., "parameters": { ... } }]</TOOLCALL>`
/// - `<|python_tag|>{ "name": ..., "arguments": { ... } }`
/// - Raw JSON of:
///     - `CalledFunctionParameters`: `{ "name": ..., "parameters": { ... } }`
///     - `CalledFunctionArguments`: `{ "name": ..., "arguments": { ... } }`
///     - Or a list of either of those types: `[ { "name": ..., "arguments": { ... } }, ... ]`
///
/// # Return
///
/// - `Ok(Some(ToolCallResponse))` if parsing succeeds
/// - `Ok(None)` if input format is unrecognized or invalid JSON
/// - `Err(...)` if JSON is valid but deserialization or argument re-serialization fails
///
/// # Note on List Handling
///
/// When the input contains a list of tool calls (either with `parameters` or `arguments`),
/// only the **last item** in the list is returned. This design choice assumes that the
/// most recent tool call in a list is the one to execute.
///
/// # Errors
///
/// Returns a `Result::Err` only if an inner `serde_json::to_string(...)` fails
/// (e.g., if the arguments are not serializable).
///
/// # Examples
///
/// ```ignore
/// let input = r#"<TOOLCALL>[{ "name": "search", "parameters": { "query": "rust" } }]</TOOLCALL>"#;
/// let result = try_tool_call_parse_json(input)?;
/// assert!(result.is_some());
/// ```
pub fn try_tool_call_parse_json(
    message: &str,
    config: &JsonParserConfig,
145
) -> anyhow::Result<Vec<ToolCallResponse>> {
146
147
148
149
    // Log the config we are using
    tracing::debug!("Using JSON parser config: {:?}", config);
    let trimmed = message.trim();

150
151
152
    // Use config to get tool call start and end token vectors, then use the first element for now
    let tool_call_start_tokens = &config.tool_call_start_tokens;
    let tool_call_end_tokens = &config.tool_call_end_tokens;
153

154
155
156
157
    assert!(
        tool_call_start_tokens.len() == tool_call_end_tokens.len(),
        "Tool call start and end tokens must have the same length"
    );
158

159
    // Iterate over all start and end tokens and try to extract the content between them
160
161
    // Assumption : One message will not contain different tags for tool calls. Iteration over tags is to support different tags by default for multiple models
    let mut json = trimmed.to_string();
162
163
164
165
166
167
    for (start_token, end_token) in tool_call_start_tokens
        .iter()
        .zip(tool_call_end_tokens.iter())
    {
        // Special case for <|python_tag|> . Regex pattern does not work well with it as it has no end token
        json = if !start_token.is_empty() && end_token.is_empty() {
168
169
            handle_single_token_tool_calls(&json, start_token)
        } else if let Some(content) = extract_tool_call_content(&json, start_token, end_token) {
170
171
172
173
174
            content
        } else {
            json
        };
    }
175

176
    // Convert json (String) to &str
177
178
    let json = json.as_str();

179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
    // Anonymous function to attempt deserialization into a known representation
    let parse = |name: String, args: HashMap<String, Value>| -> anyhow::Result<_> {
        Ok(ToolCallResponse {
            id: format!("call-{}", Uuid::new_v4()),
            tp: ToolCallType::Function,
            function: CalledFunction {
                name,
                arguments: serde_json::to_string(&args)?,
            },
        })
    };

    // CalledFunctionParameters: Single { name, parameters }
    // Example:
    // {
    //   "name": "search_docs",
    //   "parameters": {
    //     "query": "how to use Rust",
    //     "limit": 5
    //   }
    // }
    if let Ok(single) = serde_json::from_str::<CalledFunctionParameters>(json) {
201
202
203
204
205
206
207
208
209
210
211
212
        return Ok(vec![parse(single.name, single.parameters)?]);
        //parse(single.name, single.parameters).map(Some);

        // CalledFunctionArguments: Single { name, arguments }
        // Example:
        // {
        //   "name": "summarize",
        //   "arguments": {
        //     "text": "Rust is a systems programming language.",
        //     "length": "short"
        //   }
        // }
213
    } else if let Ok(single) = serde_json::from_str::<CalledFunctionArguments>(json) {
214
        return Ok(vec![parse(single.name, single.arguments)?]);
215
216
217
218
219
220
221
222

    // Vec<CalledFunctionParameters>: List of { name, parameters }
    // Example:
    // [
    //   { "name": "lookup_user", "parameters": { "user_id": "123" } },
    //   { "name": "send_email", "parameters": { "to": "user@example.com", "subject": "Welcome!" } }
    // ]
    // We pop the last item in the list to use.
223
224
225
226
    } else if let Ok(list) = serde_json::from_str::<Vec<CalledFunctionParameters>>(json) {
        let mut results = Vec::new();
        for item in list {
            results.push(parse(item.name, item.parameters)?);
227
        }
228
        return Ok(results);
229
230
231
232
233
234
235
236
237
238
239
240
241

    // Vec<CalledFunctionArguments>: List of { name, arguments }
    // Example:
    // [
    //   {
    //     "name": "get_weather",
    //     "arguments": {
    //       "location": "San Francisco",
    //       "units": "celsius"
    //     }
    //   }
    // ]
    // Again, we take the last item for processing.
242
243
244
245
    } else if let Ok(list) = serde_json::from_str::<Vec<CalledFunctionArguments>>(json) {
        let mut results = Vec::new();
        for item in list {
            results.push(parse(item.name, item.arguments)?);
246
        }
247
        return Ok(results);
248
249
    }

250
    Ok(vec![])
251
}