// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use super::json::JsonParserType;

#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct JsonParserConfig {
    /// Start token for individual tool calls (e.g., "<TOOLCALL>")
    pub tool_call_start_tokens: Vec<String>,
    /// End token for individual tool calls (e.g., "</TOOLCALL>")
    pub tool_call_end_tokens: Vec<String>,
12
13
14
15
    /// Separator tokens between function name and arguments
    /// (e.g., "<|tool▁sep|>" for DeepSeek v3.1)
    /// Used by some models to separate function name from arguments
    pub tool_call_separator_tokens: Vec<String>,
16
17
18
19
20
21
22
23
    /// The key for the function name in the tool call
    /// i.e. `{"name": "function", "arguments": {...}}` it would be
    /// "name"
    pub function_name_keys: Vec<String>,
    /// The key for the arguments in the tool call
    /// i.e. `{"name": "function", "arguments": {...}}` it would be
    /// "arguments"
    pub arguments_keys: Vec<String>,
24
25
26
27

    /// The type of JSON parser to use
    #[serde(default)]
    pub parser_type: JsonParserType,
28
29
30
31
32
33
34
}

impl Default for JsonParserConfig {
    fn default() -> Self {
        Self {
            tool_call_start_tokens: vec!["<TOOLCALL>".to_string(), "<|python_tag|>".to_string()],
            tool_call_end_tokens: vec!["</TOOLCALL>".to_string(), "".to_string()],
35
            tool_call_separator_tokens: vec![],
36
37
            function_name_keys: vec!["name".to_string()],
            arguments_keys: vec!["arguments".to_string(), "parameters".to_string()],
38
            parser_type: JsonParserType::Basic,
39
40
41
42
        }
    }
}

/// Token set for XML-style tool calls built from tool-call, function, and
/// parameter markers.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub struct XmlParserConfig {
    /// Marker that opens a single tool call, e.g. `<tool_call>`.
    pub tool_call_start_token: String,
    /// Marker that closes a single tool call, e.g. `</tool_call>`.
    pub tool_call_end_token: String,
    /// Marker that precedes the function name, e.g. `<function=`.
    pub function_start_token: String,
    /// Marker that closes a function, e.g. `</function>`.
    pub function_end_token: String,
    /// Marker that precedes a parameter name, e.g. `<parameter=`.
    pub parameter_start_token: String,
    /// Marker that closes a parameter, e.g. `</parameter>`.
    pub parameter_end_token: String,
}

impl Default for XmlParserConfig {
    /// Returns the `<tool_call>` / `<function=` / `<parameter=` token set.
    fn default() -> Self {
        let tool_call_start_token = String::from("<tool_call>");
        let tool_call_end_token = String::from("</tool_call>");
        let function_start_token = String::from("<function=");
        let function_end_token = String::from("</function>");
        let parameter_start_token = String::from("<parameter=");
        let parameter_end_token = String::from("</parameter>");

        Self {
            tool_call_start_token,
            tool_call_end_token,
            function_start_token,
            function_end_token,
            parameter_start_token,
            parameter_end_token,
        }
    }
}

/// Token set for the DSML-style tool call parser (DeepSeek V3.2+).
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub struct DsmlParserConfig {
    /// Marker that opens the function_calls block, e.g. `<|DSML|function_calls>`.
    pub function_calls_start: String,
    /// Marker that closes the function_calls block, e.g. `</|DSML|function_calls>`.
    pub function_calls_end: String,
    /// Prefix that begins an invoke element, e.g. `<|DSML|invoke name=`.
    pub invoke_start_prefix: String,
    /// Marker that closes an invoke element, e.g. `</|DSML|invoke>`.
    pub invoke_end: String,
    /// Prefix that begins a parameter element, e.g. `<|DSML|parameter name=`.
    pub parameter_prefix: String,
    /// Marker that closes a parameter element, e.g. `</|DSML|parameter>`.
    pub parameter_end: String,
}

impl Default for DsmlParserConfig {
    /// Returns the `<|DSML|...>` marker set for function_calls, invoke, and
    /// parameter elements.
    fn default() -> Self {
        let function_calls_start = String::from("<|DSML|function_calls>");
        let function_calls_end = String::from("</|DSML|function_calls>");
        let invoke_start_prefix = String::from("<|DSML|invoke name=");
        let invoke_end = String::from("</|DSML|invoke>");
        let parameter_prefix = String::from("<|DSML|parameter name=");
        let parameter_end = String::from("</|DSML|parameter>");

        Self {
            function_calls_start,
            function_calls_end,
            invoke_start_prefix,
            invoke_end,
            parameter_prefix,
            parameter_end,
        }
    }
}

/// Token set for the GLM-4.7 style tool call parser.
///
/// Format:
/// `<tool_call>function_name<arg_key>param</arg_key><arg_value>value</arg_value></tool_call>`
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub struct Glm47ParserConfig {
    /// Marker that opens the tool call block, e.g. `<tool_call>`.
    pub tool_call_start: String,
    /// Marker that closes the tool call block, e.g. `</tool_call>`.
    pub tool_call_end: String,
    /// Marker that opens an argument key, e.g. `<arg_key>`.
    pub arg_key_start: String,
    /// Marker that closes an argument key, e.g. `</arg_key>`.
    pub arg_key_end: String,
    /// Marker that opens an argument value, e.g. `<arg_value>`.
    pub arg_value_start: String,
    /// Marker that closes an argument value, e.g. `</arg_value>`.
    pub arg_value_end: String,
}

impl Default for Glm47ParserConfig {
    /// Returns the `<tool_call>` / `<arg_key>` / `<arg_value>` marker set.
    fn default() -> Self {
        let tool_call_start = String::from("<tool_call>");
        let tool_call_end = String::from("</tool_call>");
        let arg_key_start = String::from("<arg_key>");
        let arg_key_end = String::from("</arg_key>");
        let arg_value_start = String::from("<arg_value>");
        let arg_value_end = String::from("</arg_value>");

        Self {
            tool_call_start,
            tool_call_end,
            arg_key_start,
            arg_key_end,
            arg_value_start,
            arg_value_end,
        }
    }
}

/// Parser-specific configuration
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ParserConfig {
    Json(JsonParserConfig),
    Xml(XmlParserConfig),
    Pythonic,
    Harmony(JsonParserConfig),
    Typescript,
142
    Dsml(DsmlParserConfig),
143
    Glm47(Glm47ParserConfig),
144
145
146
147
148
149
150
151
152
153
154
155
}

impl ParserConfig {
    /// Get the tool call start tokens for this parser configuration.
    ///
    /// Returns an owned copy of the tokens that indicate the beginning of a
    /// tool call. The `Pythonic` and `Typescript` variants carry no tokens and
    /// yield an empty vector.
    pub fn tool_call_start_tokens(&self) -> Vec<String> {
        match self {
            // Both variants wrap a `JsonParserConfig`, which already holds a list.
            ParserConfig::Json(config) | ParserConfig::Harmony(config) => {
                config.tool_call_start_tokens.clone()
            }
            ParserConfig::Xml(config) => vec![config.tool_call_start_token.clone()],
            ParserConfig::Dsml(config) => vec![config.function_calls_start.clone()],
            ParserConfig::Glm47(config) => vec![config.tool_call_start.clone()],
            ParserConfig::Pythonic | ParserConfig::Typescript => Vec::new(),
        }
    }

    /// Get the tool call end tokens for this parser configuration.
    ///
    /// Returns an owned copy of the tokens that indicate the end of a tool
    /// call. The `Pythonic` and `Typescript` variants carry no tokens and
    /// yield an empty vector.
    pub fn tool_call_end_tokens(&self) -> Vec<String> {
        match self {
            // Both variants wrap a `JsonParserConfig`, which already holds a list.
            ParserConfig::Json(config) | ParserConfig::Harmony(config) => {
                config.tool_call_end_tokens.clone()
            }
            ParserConfig::Xml(config) => vec![config.tool_call_end_token.clone()],
            ParserConfig::Dsml(config) => vec![config.function_calls_end.clone()],
            ParserConfig::Glm47(config) => vec![config.tool_call_end.clone()],
            ParserConfig::Pythonic | ParserConfig::Typescript => Vec::new(),
        }
    }
}

/// Configuration for parsing tool calls with different formats
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct ToolCallConfig {
179
180
    /// Parser-specific configuration.
    pub parser_config: ParserConfig,
181
182
183
184
185
}

impl Default for ToolCallConfig {
    fn default() -> Self {
        Self {
186
            parser_config: ParserConfig::Json(JsonParserConfig::default()),
187
188
189
190
191
192
193
194
195
        }
    }
}

impl ToolCallConfig {
    /// Default configuration for hermes tool calls
    /// <tool_call>{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}\n</tool_call>
    pub fn hermes() -> Self {
        Self {
196
            parser_config: ParserConfig::Json(JsonParserConfig {
197
                tool_call_start_tokens: vec!["<tool_call>".to_string()],
198
                tool_call_end_tokens: vec!["</tool_call>".to_string()],
199
                ..Default::default()
200
            }),
201
202
203
204
205
206
207
        }
    }

    /// Default configuration for nemotron tool calls
    /// <TOOLCALL>[{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}]</TOOLCALL>
    pub fn nemotron_deci() -> Self {
        Self {
208
            parser_config: ParserConfig::Json(JsonParserConfig {
209
210
211
                tool_call_start_tokens: vec!["<TOOLCALL>".to_string()],
                tool_call_end_tokens: vec!["</TOOLCALL>".to_string()],
                ..Default::default()
212
            }),
213
214
215
216
217
218
219
        }
    }

    pub fn llama3_json() -> Self {
        // <|python_tag|>{ "name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"} }
        // or { "name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"} }
        Self {
220
            parser_config: ParserConfig::Json(JsonParserConfig {
221
222
223
                tool_call_start_tokens: vec!["<|python_tag|>".to_string()],
                tool_call_end_tokens: vec!["".to_string()],
                ..Default::default()
224
            }),
225
226
227
228
229
        }
    }

    pub fn mistral() -> Self {
        Self {
230
            parser_config: ParserConfig::Json(JsonParserConfig {
231
                tool_call_start_tokens: vec!["[TOOL_CALLS]".to_string()],
232
                tool_call_end_tokens: vec!["[/TOOL_CALLS]".to_string(), "".to_string()],
233
                ..Default::default()
234
            }),
235
236
237
238
239
        }
    }

    pub fn phi4() -> Self {
        Self {
240
            parser_config: ParserConfig::Json(JsonParserConfig {
241
242
243
                tool_call_start_tokens: vec!["functools".to_string()],
                tool_call_end_tokens: vec!["".to_string()],
                ..Default::default()
244
            }),
245
246
247
248
249
        }
    }

    pub fn pythonic() -> Self {
        Self {
250
            parser_config: ParserConfig::Pythonic,
251
252
        }
    }
253
254
255

    pub fn harmony() -> Self {
        Self {
256
            parser_config: ParserConfig::Harmony(JsonParserConfig {
257
258
259
                tool_call_start_tokens: vec!["<|start|>assistant<|channel|>commentary".to_string()],
                tool_call_end_tokens: vec!["<|call|>".to_string()],
                ..Default::default()
260
            }),
261
262
        }
    }
263
264

    pub fn deepseek_v3_1() -> Self {
265
266
267
268
269
270
        // The whole tool calls block is wrapped between
        // <|tool▁calls▁begin|> ... <|tool▁calls▁end|>
        // regardless of number of tool calls. For external use of this
        // config, we want them to only be operating on the whole block,
        // so the tool parser can properly consume all tool call tokens.
        // https://huggingface.co/deepseek-ai/DeepSeek-V3.1#toolcall
271
        Self {
272
            parser_config: ParserConfig::Json(JsonParserConfig {
273
274
                tool_call_start_tokens: vec![
                    "<|tool▁calls▁begin|>".to_string(),
275
                    // "<|tool▁call▁begin|>".to_string(),
276
                ],
277
278
                tool_call_end_tokens: vec![
                    "<|tool▁calls▁end|>".to_string(),
279
                    // "<|tool▁call▁end|>".to_string(),
280
281
                ],
                tool_call_separator_tokens: vec!["<|tool▁sep|>".to_string()],
282
283
                parser_type: JsonParserType::DeepseekV31,
                ..Default::default()
284
            }),
285
286
        }
    }
287
288
289
290
291
292

    pub fn deepseek_v3() -> Self {
        // DeepSeek V3 format:
        // <|tool▁calls▁begin|><|tool▁call▁begin|>{type}<|tool▁sep|>{function_name}\n```json\n{arguments}\n```<|tool▁call▁end|><|tool▁calls▁end|>
        // There are some differences between DeepSeek V3 and DeepSeek V3.1
        Self {
293
            parser_config: ParserConfig::Json(JsonParserConfig {
294
295
296
297
298
                tool_call_start_tokens: vec!["<|tool▁calls▁begin|>".to_string()],
                tool_call_end_tokens: vec!["<|tool▁calls▁end|>".to_string()],
                tool_call_separator_tokens: vec!["<|tool▁sep|>".to_string()],
                parser_type: JsonParserType::DeepseekV3,
                ..Default::default()
299
            }),
300
301
        }
    }
302
303
304
305

    pub fn qwen3_coder() -> Self {
        // <tool_call><function=name><parameter=key>value</parameter></function></tool_call>
        Self {
306
            parser_config: ParserConfig::Xml(XmlParserConfig::default()),
307
308
        }
    }
309
310
311
312
313
314
315
316
317
318

    pub fn jamba() -> Self {
        Self {
            parser_config: ParserConfig::Json(JsonParserConfig {
                tool_call_start_tokens: vec!["<tool_calls>".to_string()],
                tool_call_end_tokens: vec!["</tool_calls>".to_string()],
                ..Default::default()
            }),
        }
    }
319
320
321
322
323
324
325
326
327
328
329
330

    pub fn deepseek_v3_2() -> Self {
        // DeepSeek V3.2 format (DSML):
        // <|DSML|function_calls>
        // <|DSML|invoke name="function_name">
        // <|DSML|parameter name="param_name" string="true|false">value</|DSML|parameter>
        // </|DSML|invoke>
        // </|DSML|function_calls>
        Self {
            parser_config: ParserConfig::Dsml(DsmlParserConfig::default()),
        }
    }
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350

    pub fn minimax_m2() -> Self {
        // MiniMax-M2.1 format:
        // <minimax:tool_call>
        // <invoke name="function_name">
        // <parameter name="param_name">value</parameter>
        // </invoke>
        // </minimax:tool_call>
        // Reference: https://huggingface.co/MiniMaxAI/MiniMax-M2.1/blob/main/docs/tool_calling_guide.md
        Self {
            parser_config: ParserConfig::Xml(XmlParserConfig {
                tool_call_start_token: "<minimax:tool_call>".to_string(),
                tool_call_end_token: "</minimax:tool_call>".to_string(),
                function_start_token: "<invoke name=".to_string(),
                function_end_token: "</invoke>".to_string(),
                parameter_start_token: "<parameter name=".to_string(),
                parameter_end_token: "</parameter>".to_string(),
            }),
        }
    }
351
352
353
354
355
356
357
358
359

    pub fn glm47() -> Self {
        // GLM-4.7 format:
        // <tool_call>function_name<arg_key>param1</arg_key><arg_value>value1</arg_value></tool_call>
        // Reference: https://huggingface.co/zai-org/GLM-4.7/blob/main/chat_template.jinja
        Self {
            parser_config: ParserConfig::Glm47(Glm47ParserConfig::default()),
        }
    }
360
}