mod.rs 5.83 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
// SPDX-License-Identifier: Apache-2.0

pub mod base_json_parser;
5
6
pub mod deepseek_v3_1_parser;
pub mod deepseek_v3_parser;
7
8

pub use super::{config, response};
9
pub use base_json_parser::{detect_tool_call_start_basic_json, try_tool_call_parse_basic_json};
10
11
12
13
pub use deepseek_v3_1_parser::{
    detect_tool_call_start_deepseek_v3_1, parse_tool_calls_deepseek_v3_1,
};
pub use deepseek_v3_parser::{detect_tool_call_start_deepseek_v3, parse_tool_calls_deepseek_v3};
14
15
16
17

pub use super::config::JsonParserConfig;
pub use super::response::ToolCallResponse;

18
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize, Default)]
19
20
pub enum JsonParserType {
    // Basic is generic json parser which can handle most of the cases
21
    #[default]
22
23
    Basic,
    // Model Specific JSON Parsers
24
    DeepseekV3,
25
26
27
28
29
30
    DeepseekV31,
}

pub fn try_tool_call_parse_json(
    message: &str,
    config: &JsonParserConfig,
31
    tools: Option<&[super::ToolDefinition]>,
32
33
) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> {
    match config.parser_type {
34
35
36
        JsonParserType::Basic => try_tool_call_parse_basic_json(message, config, tools),
        JsonParserType::DeepseekV3 => parse_tool_calls_deepseek_v3(message, config, tools),
        JsonParserType::DeepseekV31 => parse_tool_calls_deepseek_v3_1(message, config, tools),
37
38
    }
}
39
40
41
42

pub fn detect_tool_call_start_json(chunk: &str, config: &JsonParserConfig) -> bool {
    match config.parser_type {
        JsonParserType::Basic => detect_tool_call_start_basic_json(chunk, config),
43
        JsonParserType::DeepseekV3 => detect_tool_call_start_deepseek_v3(chunk, config),
44
45
46
        JsonParserType::DeepseekV31 => detect_tool_call_start_deepseek_v3_1(chunk, config),
    }
}
47
48
49
50
51
52
53
54

pub fn find_tool_call_end_position_json(
    chunk: &str,
    parser: &str,
    config: &JsonParserConfig,
) -> usize {
    match parser {
        "hermes" | "nemotron_deci" => {
55
            let start_token = config.tool_call_start_tokens.first().map(|s| s.as_str());
56
            if let Some(end_token) = config.tool_call_end_tokens.first() {
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
                let Some(first_end) = chunk.find(end_token.as_str()) else {
                    return chunk.len();
                };
                let mut cursor = first_end + end_token.len();

                // Advance past any additional consecutive start→end blocks
                // so that parallel tool calls are captured as one jailed region.
                if let Some(start_tok) = start_token {
                    loop {
                        let rest = &chunk[cursor..];
                        let trimmed = rest.trim_start();
                        if !trimmed.starts_with(start_tok) {
                            break;
                        }
                        let trim_offset = rest.len() - trimmed.len();
                        let search_from = cursor + trim_offset + start_tok.len();
                        if let Some(end_pos) = chunk[search_from..].find(end_token.as_str()) {
                            cursor = search_from + end_pos + end_token.len();
                        } else {
                            break;
                        }
                    }
79
                }
80
                cursor
81
82
83
84
85
86
87
88
89
90
91
92
93
94
            } else {
                chunk.len()
            }
        }
        "mistral" | "phi4" => {
            if let Some(pos) = chunk.rfind(']') {
                pos + 1
            } else {
                chunk.len()
            }
        }
        _ => chunk.len(),
    }
}
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152

#[cfg(test)]
mod tests {
    use super::*;

    fn tagged_config(start: &str, end: &str) -> JsonParserConfig {
        JsonParserConfig {
            tool_call_start_tokens: vec![start.to_string()],
            tool_call_end_tokens: vec![end.to_string()],
            ..Default::default()
        }
    }

    /// Regression test for issue #6822: parallel tool calls in a single chunk must
    /// all be captured by find_tool_call_end_position_json so that the jail passes the
    /// entire group to the parser rather than emitting the second (and later) calls
    /// as raw trailing text.
    #[test]
    fn test_find_tool_call_end_position_parallel_calls() {
        let config = tagged_config("<tool_call>", "</tool_call>");

        // Two parallel calls with no whitespace between them.
        let two_calls = concat!(
            "<tool_call>{\"name\": \"foo\", \"arguments\": {\"x\": 1}}</tool_call>",
            "<tool_call>{\"name\": \"bar\", \"arguments\": {\"y\": 2}}</tool_call>",
            "trailing"
        );
        let pos = find_tool_call_end_position_json(two_calls, "hermes", &config);
        assert!(
            two_calls[..pos].ends_with("</tool_call>"),
            "should end at last </tool_call>, got: {:?}",
            &two_calls[..pos]
        );
        assert_eq!(&two_calls[pos..], "trailing");

        // Three parallel calls separated by whitespace / newlines.
        let three_calls = concat!(
            "<tool_call>{\"name\": \"a\"}</tool_call>\n",
            "<tool_call>{\"name\": \"b\"}</tool_call>\n",
            "<tool_call>{\"name\": \"c\"}</tool_call> done"
        );
        let pos3 = find_tool_call_end_position_json(three_calls, "hermes", &config);
        assert!(
            three_calls[..pos3].ends_with("</tool_call>"),
            "should end at last </tool_call>, got: {:?}",
            &three_calls[..pos3]
        );
        assert_eq!(three_calls[pos3..].trim(), "done");

        // Incomplete second call — should stop after the first complete one.
        let incomplete = concat!(
            "<tool_call>{\"name\": \"a\"}</tool_call>",
            "<tool_call>{\"name\": \"b\""
        );
        let pos_inc = find_tool_call_end_position_json(incomplete, "hermes", &config);
        let first_end = "<tool_call>{\"name\": \"a\"}</tool_call>".len();
        assert_eq!(
            pos_inc, first_end,
            "should stop at end of first complete call when second is incomplete"
        );
    }

    /// The `nemotron_deci` parser shares the tagged-block scan with `hermes`.
    #[test]
    fn test_find_tool_call_end_position_nemotron_deci() {
        let config = tagged_config("<TOOLCALL>", "</TOOLCALL>");
        let chunk = concat!(
            "<TOOLCALL>[{\"name\": \"a\"}]</TOOLCALL> ",
            "<TOOLCALL>[{\"name\": \"b\"}]</TOOLCALL>tail"
        );
        let pos = find_tool_call_end_position_json(chunk, "nemotron_deci", &config);
        assert_eq!(&chunk[pos..], "tail");
    }

    /// Without a complete end token in the chunk the whole chunk stays jailed.
    #[test]
    fn test_find_tool_call_end_position_without_end_token_in_chunk() {
        let config = tagged_config("<tool_call>", "</tool_call>");
        let partial = "<tool_call>{\"name\": \"a\"";
        assert_eq!(
            find_tool_call_end_position_json(partial, "hermes", &config),
            partial.len()
        );
    }

    /// `mistral` / `phi4` end right after the last `]`, or at the end of the chunk.
    #[test]
    fn test_find_tool_call_end_position_bracket_parsers() {
        let config = JsonParserConfig::default();
        let chunk = "[TOOL_CALLS][{\"name\": \"a\"}] tail";
        for parser in ["mistral", "phi4"] {
            let pos = find_tool_call_end_position_json(chunk, parser, &config);
            assert_eq!(&chunk[pos..], " tail", "parser: {parser}");
        }
        let no_bracket = "no brackets here";
        assert_eq!(
            find_tool_call_end_position_json(no_bracket, "mistral", &config),
            no_bracket.len()
        );
    }

    /// Parsers without a dedicated rule keep the whole chunk.
    #[test]
    fn test_find_tool_call_end_position_unknown_parser() {
        let config = JsonParserConfig::default();
        let chunk = "<tool_call>{}</tool_call> tail";
        assert_eq!(
            find_tool_call_end_position_json(chunk, "pythonic", &config),
            chunk.len()
        );
    }
}