// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use super::json::JsonParserType;

6
7
8
9
10
11
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct JsonParserConfig {
    /// Start token for individual tool calls (e.g., "<TOOLCALL>")
    pub tool_call_start_tokens: Vec<String>,
    /// End token for individual tool calls (e.g., "</TOOLCALL>")
    pub tool_call_end_tokens: Vec<String>,
12
13
14
15
    /// Separator tokens between function name and arguments
    /// (e.g., "<|tool▁sep|>" for DeepSeek v3.1)
    /// Used by some models to separate function name from arguments
    pub tool_call_separator_tokens: Vec<String>,
16
17
18
19
20
21
22
23
    /// The key for the function name in the tool call
    /// i.e. `{"name": "function", "arguments": {...}}` it would be
    /// "name"
    pub function_name_keys: Vec<String>,
    /// The key for the arguments in the tool call
    /// i.e. `{"name": "function", "arguments": {...}}` it would be
    /// "arguments"
    pub arguments_keys: Vec<String>,
24
25
26
27

    /// The type of JSON parser to use
    #[serde(default)]
    pub parser_type: JsonParserType,
28
29
30
31
32
33
34
}

impl Default for JsonParserConfig {
    fn default() -> Self {
        Self {
            tool_call_start_tokens: vec!["<TOOLCALL>".to_string(), "<|python_tag|>".to_string()],
            tool_call_end_tokens: vec!["</TOOLCALL>".to_string(), "".to_string()],
35
            tool_call_separator_tokens: vec![],
36
37
            function_name_keys: vec!["name".to_string()],
            arguments_keys: vec!["arguments".to_string(), "parameters".to_string()],
38
            parser_type: JsonParserType::Basic,
39
40
41
42
        }
    }
}

43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
/// Delimiter tokens used by the XML-style tool-call parser.
///
/// Each pair of fields holds the opening and closing marker for one nesting
/// level: the tool call itself, the function, and a single parameter.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct XmlParserConfig {
    /// Marker that opens an individual tool call (e.g., "<tool_call>").
    pub tool_call_start_token: String,
    /// Marker that closes an individual tool call (e.g., "</tool_call>").
    pub tool_call_end_token: String,
    /// Marker that opens a function and precedes its name (e.g., "<function=").
    pub function_start_token: String,
    /// Marker that closes a function (e.g., "</function>").
    pub function_end_token: String,
    /// Marker that opens a parameter (e.g., "<parameter=").
    pub parameter_start_token: String,
    /// Marker that closes a parameter (e.g., "</parameter>").
    pub parameter_end_token: String,
}

impl Default for XmlParserConfig {
    /// Returns the stock XML markers: `<tool_call>` / `</tool_call>`,
    /// `<function=` / `</function>`, and `<parameter=` / `</parameter>`.
    fn default() -> Self {
        let tool_call_start_token = String::from("<tool_call>");
        let tool_call_end_token = String::from("</tool_call>");
        let function_start_token = String::from("<function=");
        let function_end_token = String::from("</function>");
        let parameter_start_token = String::from("<parameter=");
        let parameter_end_token = String::from("</parameter>");

        Self {
            tool_call_start_token,
            tool_call_end_token,
            function_start_token,
            function_end_token,
            parameter_start_token,
            parameter_end_token,
        }
    }
}

/// Parser-specific configuration, one variant per supported tool-call format.
///
/// With serde the enum is internally tagged: the variant is stored in a
/// `type` field whose value is the snake_case variant name (e.g. `json`).
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ParserConfig {
    /// JSON tool calls, configured by a [`JsonParserConfig`].
    Json(JsonParserConfig),
    /// XML-style tool calls, configured by an [`XmlParserConfig`].
    Xml(XmlParserConfig),
    /// Pythonic tool calls; carries no settings.
    Pythonic,
    /// Harmony tool calls; reuses [`JsonParserConfig`] for its settings.
    Harmony(JsonParserConfig),
    /// Typescript tool calls; carries no settings.
    Typescript,
}

impl ParserConfig {
    /// Get the tool call start tokens for this parser configuration
    /// Returns a vector of start tokens that indicate the beginning of a tool call
    pub fn tool_call_start_tokens(&self) -> Vec<String> {
        match self {
            ParserConfig::Json(config) => config.tool_call_start_tokens.clone(),
            ParserConfig::Harmony(config) => config.tool_call_start_tokens.clone(),
            ParserConfig::Xml(config) => vec![config.tool_call_start_token.clone()],
            ParserConfig::Pythonic => vec![],
            ParserConfig::Typescript => vec![],
        }
    }

    /// Get the tool call end tokens for this parser configuration
    /// Returns a vector of end tokens that indicate the end of a tool call
    pub fn tool_call_end_tokens(&self) -> Vec<String> {
        match self {
            ParserConfig::Json(config) => config.tool_call_end_tokens.clone(),
            ParserConfig::Harmony(config) => config.tool_call_end_tokens.clone(),
            ParserConfig::Xml(config) => vec![config.tool_call_end_token.clone()],
            ParserConfig::Pythonic => vec![],
            ParserConfig::Typescript => vec![],
        }
    }
}

109
110
111
/// Configuration for parsing tool calls with different formats
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct ToolCallConfig {
112
113
    /// Parser-specific configuration.
    pub parser_config: ParserConfig,
114
115
116
117
118
}

impl Default for ToolCallConfig {
    fn default() -> Self {
        Self {
119
            parser_config: ParserConfig::Json(JsonParserConfig::default()),
120
121
122
123
124
125
126
127
128
        }
    }
}

impl ToolCallConfig {
    /// Default configuration for hermes tool calls
    /// <tool_call>{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}\n</tool_call>
    pub fn hermes() -> Self {
        Self {
129
            parser_config: ParserConfig::Json(JsonParserConfig {
130
                tool_call_start_tokens: vec!["<tool_call>".to_string()],
131
                tool_call_end_tokens: vec!["</tool_call>".to_string()],
132
                ..Default::default()
133
            }),
134
135
136
137
138
139
140
        }
    }

    /// Default configuration for nemotron tool calls
    /// <TOOLCALL>[{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}]</TOOLCALL>
    pub fn nemotron_deci() -> Self {
        Self {
141
            parser_config: ParserConfig::Json(JsonParserConfig {
142
143
144
                tool_call_start_tokens: vec!["<TOOLCALL>".to_string()],
                tool_call_end_tokens: vec!["</TOOLCALL>".to_string()],
                ..Default::default()
145
            }),
146
147
148
149
150
151
152
        }
    }

    pub fn llama3_json() -> Self {
        // <|python_tag|>{ "name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"} }
        // or { "name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"} }
        Self {
153
            parser_config: ParserConfig::Json(JsonParserConfig {
154
155
156
                tool_call_start_tokens: vec!["<|python_tag|>".to_string()],
                tool_call_end_tokens: vec!["".to_string()],
                ..Default::default()
157
            }),
158
159
160
161
162
        }
    }

    pub fn mistral() -> Self {
        Self {
163
            parser_config: ParserConfig::Json(JsonParserConfig {
164
                tool_call_start_tokens: vec!["[TOOL_CALLS]".to_string()],
165
                tool_call_end_tokens: vec!["[/TOOL_CALLS]".to_string(), "".to_string()],
166
                ..Default::default()
167
            }),
168
169
170
171
172
        }
    }

    pub fn phi4() -> Self {
        Self {
173
            parser_config: ParserConfig::Json(JsonParserConfig {
174
175
176
                tool_call_start_tokens: vec!["functools".to_string()],
                tool_call_end_tokens: vec!["".to_string()],
                ..Default::default()
177
            }),
178
179
180
181
182
        }
    }

    pub fn pythonic() -> Self {
        Self {
183
            parser_config: ParserConfig::Pythonic,
184
185
        }
    }
186
187
188

    pub fn harmony() -> Self {
        Self {
189
            parser_config: ParserConfig::Harmony(JsonParserConfig {
190
191
192
                tool_call_start_tokens: vec!["<|start|>assistant<|channel|>commentary".to_string()],
                tool_call_end_tokens: vec!["<|call|>".to_string()],
                ..Default::default()
193
            }),
194
195
        }
    }
196
197

    pub fn deepseek_v3_1() -> Self {
198
199
200
201
202
203
        // The whole tool calls block is wrapped between
        // <|tool▁calls▁begin|> ... <|tool▁calls▁end|>
        // regardless of number of tool calls. For external use of this
        // config, we want them to only be operating on the whole block,
        // so the tool parser can properly consume all tool call tokens.
        // https://huggingface.co/deepseek-ai/DeepSeek-V3.1#toolcall
204
        Self {
205
            parser_config: ParserConfig::Json(JsonParserConfig {
206
207
                tool_call_start_tokens: vec![
                    "<|tool▁calls▁begin|>".to_string(),
208
                    // "<|tool▁call▁begin|>".to_string(),
209
                ],
210
211
                tool_call_end_tokens: vec![
                    "<|tool▁calls▁end|>".to_string(),
212
                    // "<|tool▁call▁end|>".to_string(),
213
214
                ],
                tool_call_separator_tokens: vec!["<|tool▁sep|>".to_string()],
215
216
                parser_type: JsonParserType::DeepseekV31,
                ..Default::default()
217
            }),
218
219
        }
    }
220
221
222
223
224
225

    pub fn deepseek_v3() -> Self {
        // DeepSeek V3 format:
        // <|tool▁calls▁begin|><|tool▁call▁begin|>{type}<|tool▁sep|>{function_name}\n```json\n{arguments}\n```<|tool▁call▁end|><|tool▁calls▁end|>
        // There are some differences between DeepSeek V3 and DeepSeek V3.1
        Self {
226
            parser_config: ParserConfig::Json(JsonParserConfig {
227
228
229
230
231
                tool_call_start_tokens: vec!["<|tool▁calls▁begin|>".to_string()],
                tool_call_end_tokens: vec!["<|tool▁calls▁end|>".to_string()],
                tool_call_separator_tokens: vec!["<|tool▁sep|>".to_string()],
                parser_type: JsonParserType::DeepseekV3,
                ..Default::default()
232
            }),
233
234
        }
    }
235
236
237
238

    pub fn qwen3_coder() -> Self {
        // <tool_call><function=name><parameter=key>value</parameter></function></tool_call>
        Self {
239
            parser_config: ParserConfig::Xml(XmlParserConfig::default()),
240
241
        }
    }
242
}