base_parser.rs 19.5 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
// SPDX-License-Identifier: Apache-2.0

4
use crate::{ParserResult, ReasoningParser};
5

6
7
/// Reasoning parser for models that wrap their chain of thought in a pair of delimiter
/// tokens (for example `<think>` ... `</think>`).
///
/// Text between the start and end tokens is reported as reasoning and text after the end
/// token as normal output. Both one-shot and incremental (streaming) parsing are supported;
/// the fields below carry the streaming state between calls.
#[derive(Default, Debug, Clone)]
pub struct BasicReasoningParser {
    /// Token that opens a reasoning block.
    think_start_token: String,
    /// Token that closes a reasoning block.
    think_end_token: String,
    /// Whether the parser currently considers itself inside a reasoning block.
    /// Initialized from `force_reasoning` in [`BasicReasoningParser::new`].
    _in_reasoning: bool,
    /// When `true`, streamed reasoning is emitted as it arrives instead of being
    /// accumulated until the end token is seen.
    stream_reasoning: bool,
    /// Streamed text that has been received but not yet emitted.
    _buffer: String,
    /// `true` once the streaming parser has stopped looking for the start token (e.g. after
    /// stripping it). While `false`, a buffer that is a partial start token is held back.
    stripped_think_start: bool,
}

impl BasicReasoningParser {
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
    pub fn new(
        think_start_token: String,
        think_end_token: String,
        force_reasoning: bool,
        stream_reasoning: bool,
    ) -> Self {
        Self {
            think_start_token,
            think_end_token,
            _in_reasoning: force_reasoning,
            stream_reasoning,
            _buffer: String::new(),
            stripped_think_start: false,
        }
    }
}

34
impl ReasoningParser for BasicReasoningParser {
35
    fn detect_and_parse_reasoning(&mut self, text: &str, _token_ids: &[u32]) -> ParserResult {
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
        let in_reasoning = self._in_reasoning || text.contains(&self.think_start_token);
        if !in_reasoning {
            return ParserResult {
                normal_text: text.to_string(),
                reasoning_text: String::new(),
            };
        }

        // The text is considered to be in a reasoning block.
        let processed_text = text.replace(&self.think_start_token, "").trim().to_string();

        if !processed_text.contains(&self.think_end_token) {
            // Assume reasoning was truncated before `think_end_token`
            return ParserResult {
                normal_text: String::new(),
                reasoning_text: processed_text,
            };
        }

        // Extract reasoning content
        let splits: Vec<&str> = processed_text.splitn(2, &self.think_end_token).collect();
        let reasoning_text = splits.first().unwrap_or(&"").to_string();
        let normal_text = splits
            .get(1)
            .map(|s| s.trim().to_string())
            .unwrap_or_default();

        ParserResult {
            normal_text,
            reasoning_text,
        }
    }

69
70
71
72
73
    fn parse_reasoning_streaming_incremental(
        &mut self,
        text: &str,
        _token_ids: &[u32],
    ) -> ParserResult {
74
75
76
        // Incrementally parse the streaming text
        self._buffer.push_str(text);
        let mut current_text = self._buffer.to_string();
77
78
79
80
        // If the current text is a prefix of the think token, keep buffering.
        // Only buffer for start token if we haven't found it yet.
        // Only buffer for end token if we're currently inside a reasoning block.
        // After reasoning ends, all content passes through as normal text.
81

82
83
        if !self.stripped_think_start
            && self.think_start_token.starts_with(&current_text)
84
85
86
87
88
89
90
            && self.think_start_token.as_str() != current_text.as_str()
        {
            return ParserResult {
                normal_text: String::new(),
                reasoning_text: String::new(),
            };
        }
91
92
        if self._in_reasoning
            && self.think_end_token.starts_with(&current_text)
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
            && self.think_end_token.as_str() != current_text.as_str()
        {
            return ParserResult {
                normal_text: String::new(),
                reasoning_text: String::new(),
            };
        }

        // Strip `<think>` token if present
        if !self.stripped_think_start && current_text.contains(&self.think_start_token) {
            current_text = current_text.replace(&self.think_start_token, "");
            self._buffer = current_text.to_string();
            self.stripped_think_start = true;
            self._in_reasoning = true;
        }
        // Handle end of reasoning block
        let mut think_end_idx = current_text.len();
        if self._in_reasoning {
            think_end_idx = current_text
                .find(&self.think_end_token)
                .unwrap_or(current_text.len());
        }
        if self._in_reasoning && think_end_idx < current_text.len() {
            let reasoning_text = &current_text[..think_end_idx];
            self._buffer.clear();
            self._in_reasoning = false;
            let start_idx = think_end_idx + self.think_end_token.len();
            let normal_text = if start_idx < current_text.len() {
                &current_text[start_idx..]
            } else {
                ""
            };
            return ParserResult {
                normal_text: normal_text.to_string(),
127
                reasoning_text: reasoning_text.to_string(),
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
            };
        }
        // Continue with reasoning content
        if self._in_reasoning && self.stream_reasoning {
            // Stream the content immediately
            let reasoning_text = current_text;
            self._buffer.clear();
            ParserResult {
                normal_text: String::new(),
                reasoning_text,
            }
        } else if !self._in_reasoning {
            // If we're not in a reasoning block return as normal text
            let normal_text = current_text;
            self._buffer.clear();
            ParserResult {
                normal_text,
                reasoning_text: String::new(),
            }
        } else {
            // If we are in a reasoning block but no end token is found, return the current buffer
            ParserResult {
                normal_text: String::new(),
                reasoning_text: String::new(),
            }
        }
    }
}

#[cfg(test)]
mod tests {
    // Unit tests for `BasicReasoningParser` covering one-shot parsing, streaming, malformed
    // input and the post-reasoning pass-through behaviour.
    use super::*;

    #[test]
    fn test_detect_and_parse_reasoning_reasoning() {
        let mut parser =
            BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
        let result =
            parser.detect_and_parse_reasoning("<think>with reasoning</think> and more text.", &[]);
        assert_eq!(result.normal_text, "and more text.");
        assert_eq!(result.reasoning_text, "with reasoning");
    }

    #[test]
    fn test_detect_and_parse_reasoning_reasoning_no_reasoning() {
        let mut parser =
            BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
        let result = parser.detect_and_parse_reasoning("This is a test without reasoning.", &[]);
        assert_eq!(result.normal_text, "This is a test without reasoning.");
        assert_eq!(result.reasoning_text, "");
    }

    #[test]
    fn test_detect_and_parse_reasoning_reasoning_truncated_reasoning() {
        let mut parser =
            BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
        let result = parser.detect_and_parse_reasoning("<think>with truncated reasoning", &[]);
        assert_eq!(result.normal_text, "");
        assert_eq!(result.reasoning_text, "with truncated reasoning");
    }

    #[test]
    fn test_parse_reasoning_streaming_incremental() {
        let mut parser =
            BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
        let result = parser.parse_reasoning_streaming_incremental("<thi", &[]);
        assert_eq!(result.normal_text, "");
        assert_eq!(result.reasoning_text, "");
    }

    #[test]
    fn test_parse_reasoning_streaming_incremental_complete() {
        let mut parser =
            BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
        let result = parser.parse_reasoning_streaming_incremental(
            "<think>with reasoning</think> and more text.",
            &[],
        );
        assert_eq!(result.normal_text, " and more text.");
        assert_eq!(result.reasoning_text, "with reasoning");
    }

    #[test]
    fn test_parse_reasoning_streaming_incremental_no_end_token() {
        let mut parser =
            BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), true, true);
        let result = parser.parse_reasoning_streaming_incremental("<think>with reasoning", &[]);
        assert_eq!(result.normal_text, "");
        assert_eq!(result.reasoning_text, "with reasoning");
    }

    #[test]
    fn test_detect_and_parse_reasoning_multiple_reasoning_blocks() {
        let mut parser =
            BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
        let result = parser.detect_and_parse_reasoning(
            "<think>first reasoning</think> middle <think>second reasoning</think> end",
            &[],
        );
        // The current implementation only handles the first occurrence properly
        assert_eq!(result.normal_text, "middle second reasoning</think> end");
        assert_eq!(result.reasoning_text, "first reasoning");
    }

    #[test]
    fn test_streaming_multiple_reasoning_blocks() {
        let mut parser =
            BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, false);
        let result1 = parser
            .parse_reasoning_streaming_incremental("<think>first reasoning</think> middle", &[]);
        assert_eq!(result1.normal_text, " middle");
        assert_eq!(result1.reasoning_text, "first reasoning");

        // Basic parser assumes only one reasoning block at a time
        let result2 = parser
            .parse_reasoning_streaming_incremental(" <think>second reasoning</think> end", &[]);
        assert_eq!(result2.normal_text, " <think>second reasoning</think> end");
        assert_eq!(result2.reasoning_text, "");
    }

    #[test]
    fn test_partial_token_matching_opening_tag() {
        let mut parser =
            BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);

        // Feed partial opening tag
        let result1 = parser.parse_reasoning_streaming_incremental("<th", &[]);
        assert_eq!(result1.normal_text, "");
        assert_eq!(result1.reasoning_text, "");

        // Complete the opening tag and add content
        let result2 = parser.parse_reasoning_streaming_incremental(
            "ink>reasoning content</think> normal text",
            &[],
        );
        assert_eq!(result2.normal_text, " normal text");
        assert_eq!(result2.reasoning_text, "reasoning content");
    }

    #[test]
    fn test_partial_token_matching_closing_tag() {
        let mut parser =
            BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, false);

        // Start with complete opening and partial content
        let result1 =
            parser.parse_reasoning_streaming_incremental("<think>reasoning content</th", &[]);
        assert_eq!(result1.normal_text, "");
        assert_eq!(result1.reasoning_text, "");

        // Complete the closing tag
        let result2 = parser.parse_reasoning_streaming_incremental("ink> normal text", &[]);
        assert_eq!(result2.normal_text, " normal text");
        assert_eq!(result2.reasoning_text, "reasoning content");
    }

    #[test]
    fn test_buffer_state_persistence_across_calls() {
        let mut parser =
            BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, false);

        // First call - partial opening tag
        let result1 = parser.parse_reasoning_streaming_incremental("<th", &[]);
        assert_eq!(result1.normal_text, "");
        assert_eq!(result1.reasoning_text, "");

        // Second call - complete opening tag, start reasoning
        let result2 = parser.parse_reasoning_streaming_incremental("ink>part1 ", &[]);
        assert_eq!(result2.normal_text, "");
        assert_eq!(result2.reasoning_text, "");

        // Third call - more reasoning content
        let result3 = parser.parse_reasoning_streaming_incremental("part2 ", &[]);
        assert_eq!(result3.normal_text, "");
        assert_eq!(result3.reasoning_text, "");

        // Fourth call - end reasoning and normal text
        let result4 = parser.parse_reasoning_streaming_incremental("part3</think> normal", &[]);
        assert_eq!(result4.normal_text, " normal");
        assert_eq!(result4.reasoning_text, "part1 part2 part3");
    }

    #[test]
    fn test_streaming_with_stream_reasoning_enabled() {
        let mut parser =
            BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);

        // Start reasoning block
        let result1 = parser.parse_reasoning_streaming_incremental("<think>reasoning ", &[]);
        assert_eq!(result1.normal_text, "");
        assert_eq!(result1.reasoning_text, "reasoning ");

        // Continue streaming reasoning
        let result2 = parser.parse_reasoning_streaming_incremental("content ", &[]);
        assert_eq!(result2.normal_text, "");
        assert_eq!(result2.reasoning_text, "content ");

        // End reasoning block
        let result3 = parser.parse_reasoning_streaming_incremental("more</think> normal", &[]);
        assert_eq!(result3.normal_text, " normal");
        assert_eq!(result3.reasoning_text, "more");
    }

    #[test]
    fn test_nested_reasoning_blocks() {
        let mut parser =
            BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
        let result = parser.detect_and_parse_reasoning(
            "<think>outer <think>inner</think> reasoning</think> normal",
            &[],
        );
        // Current implementation should handle this by finding the first closing tag
        assert_eq!(result.normal_text, "reasoning</think> normal");
        // Every <think> tag is stripped before splitting, so the nested opening tag does not
        // appear in the reasoning text
        assert_eq!(result.reasoning_text, "outer inner");
    }

    #[test]
    fn test_malformed_missing_closing_tag() {
        let mut parser =
            BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
        let result = parser.detect_and_parse_reasoning("<think>reasoning without closing tag", &[]);
        assert_eq!(result.normal_text, "");
        assert_eq!(result.reasoning_text, "reasoning without closing tag");
    }

    #[test]
    fn test_malformed_stray_closing_tag() {
        let mut parser =
            BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
        let result = parser.detect_and_parse_reasoning("normal text</think> more normal", &[]);
        assert_eq!(result.normal_text, "normal text</think> more normal");
        assert_eq!(result.reasoning_text, "");
    }

    #[test]
    fn test_malformed_multiple_opening_tags() {
        let mut parser =
            BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
        let result = parser
            .detect_and_parse_reasoning("<think>first <think>second reasoning</think> normal", &[]);
        // Should handle by replacing all opening tags and using first closing tag
        assert_eq!(result.normal_text, "normal");
        assert_eq!(result.reasoning_text, "first second reasoning");
    }

    #[test]
    fn test_empty_reasoning_block() {
        let mut parser =
            BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
        let result = parser.detect_and_parse_reasoning("<think></think> normal text", &[]);
        assert_eq!(result.normal_text, "normal text");
        assert_eq!(result.reasoning_text, "");
    }

    #[test]
    fn test_whitespace_only_reasoning_block() {
        let mut parser =
            BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
        let result = parser.detect_and_parse_reasoning("<think>   \n\t  </think> normal text", &[]);
        assert_eq!(result.normal_text, "normal text");
        assert_eq!(result.reasoning_text, ""); // Should be empty after trim
    }

    #[test]
    fn test_force_reasoning_mode() {
        let mut parser =
            BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), true, true);
        let result = parser.detect_and_parse_reasoning("no think tags here", &[]);
        assert_eq!(result.normal_text, "");
        assert_eq!(result.reasoning_text, "no think tags here");
    }

    #[test]
    fn test_streaming_reset_state_after_complete_block() {
        let mut parser =
            BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, false);

        // Process complete reasoning block
        let result1 =
            parser.parse_reasoning_streaming_incremental("<think>reasoning</think> normal", &[]);
        assert_eq!(result1.normal_text, " normal");
        assert_eq!(result1.reasoning_text, "reasoning");

        // Process normal text - should not be affected by previous state
        let result2 = parser.parse_reasoning_streaming_incremental(" more normal text", &[]);
        assert_eq!(result2.normal_text, " more normal text");
        assert_eq!(result2.reasoning_text, "");

        // Basic parser does not expect more than one reasoning block at a time
        // So this should not affect the state
        let result3 = parser
            .parse_reasoning_streaming_incremental(" <think>new reasoning</think> final", &[]);
        assert_eq!(result3.normal_text, " <think>new reasoning</think> final");
        assert_eq!(result3.reasoning_text, "");
    }

    #[test]
    fn test_post_reasoning_angle_bracket_not_buffered() {
        // After reasoning ends, a standalone `<` should pass through immediately
        // as normal text. It must NOT be buffered as a potential prefix of <think>
        // or </think>, because that would cause the downstream tool call jail to
        // miss the `<` (e.g., `<invoke` becomes `invoke`).
        let mut parser =
            BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);

        // Process a complete reasoning block
        let r1 =
            parser.parse_reasoning_streaming_incremental("<think>reasoning content</think>", &[]);
        assert_eq!(r1.reasoning_text, "reasoning content");
        assert_eq!(r1.normal_text, "");

        // After reasoning ends, a lone `<` must pass through as normal text
        let r2 = parser.parse_reasoning_streaming_incremental("<", &[]);
        assert_eq!(r2.normal_text, "<");
        assert_eq!(r2.reasoning_text, "");

        // The next token should arrive independently (not merged with buffered `<`)
        let r3 = parser.parse_reasoning_streaming_incremental("invoke name=\"get_weather\">", &[]);
        assert_eq!(r3.normal_text, "invoke name=\"get_weather\">");
        assert_eq!(r3.reasoning_text, "");
    }

    #[test]
    fn test_post_reasoning_tool_call_xml_preserved() {
        // Simulates the MiniMax tool call scenario: reasoning followed by XML tool call.
        // The `<` in `<invoke` must not be consumed by the reasoning parser.
        let mut parser =
            BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);

        let r1 = parser.parse_reasoning_streaming_incremental("<think>let me check", &[]);
        assert_eq!(r1.reasoning_text, "let me check");

        let r2 = parser.parse_reasoning_streaming_incremental("</think>", &[]);
        assert_eq!(r2.normal_text, "");
        assert_eq!(r2.reasoning_text, "");

        // Tool call markers should pass through completely
        let r3 = parser.parse_reasoning_streaming_incremental("<minimax:tool_call>", &[]);
        assert_eq!(r3.normal_text, "<minimax:tool_call>");

        let r4 = parser.parse_reasoning_streaming_incremental("\n", &[]);
        assert_eq!(r4.normal_text, "\n");

        // `<` arriving as a separate token after reasoning must NOT be buffered
        let r5 = parser.parse_reasoning_streaming_incremental("<", &[]);
        assert_eq!(r5.normal_text, "<");

        let r6 = parser.parse_reasoning_streaming_incremental("invoke name=\"get_weather\">", &[]);
        assert_eq!(r6.normal_text, "invoke name=\"get_weather\">");
    }
}