// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use crate::{ParserResult, ReasoningParser};

/// Reasoning parser that separates "thinking" text, delimited by a start and an
/// end marker token (for example `<think>` / `</think>`), from normal answer text.
#[derive(Default, Debug, Clone)]
pub struct BasicReasoningParser {
    /// Token that opens a reasoning block.
    think_start_token: String,
    /// Token that closes a reasoning block.
    think_end_token: String,
    /// Whether the parser is currently inside a reasoning block; initialised from
    /// `force_reasoning` in [`BasicReasoningParser::new`].
    _in_reasoning: bool,
    /// When streaming, emit reasoning text as it arrives instead of buffering it
    /// until the end token is seen.
    stream_reasoning: bool,
    /// Streamed text that has been received but not yet emitted.
    _buffer: String,
    /// Whether the start token has already been stripped from the streamed text.
    stripped_think_start: bool,
}

impl BasicReasoningParser {
    /// Creates a parser for the given start/end marker tokens.
    ///
    /// * `think_start_token` / `think_end_token` — markers delimiting the reasoning block.
    /// * `force_reasoning` — start inside a reasoning block even if the start token
    ///   never appears in the model output.
    /// * `stream_reasoning` — when streaming, emit reasoning incrementally instead of
    ///   buffering it until the end token arrives.
    pub fn new(
        think_start_token: String,
        think_end_token: String,
        force_reasoning: bool,
        stream_reasoning: bool,
    ) -> Self {
        Self {
            think_start_token,
            think_end_token,
            _in_reasoning: force_reasoning,
            stream_reasoning,
            _buffer: String::new(),
            stripped_think_start: false,
        }
    }
}

34
impl ReasoningParser for BasicReasoningParser {
35
    fn detect_and_parse_reasoning(&mut self, text: &str, _token_ids: &[u32]) -> ParserResult {
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
        let in_reasoning = self._in_reasoning || text.contains(&self.think_start_token);
        if !in_reasoning {
            return ParserResult {
                normal_text: text.to_string(),
                reasoning_text: String::new(),
            };
        }

        // The text is considered to be in a reasoning block.
        let processed_text = text.replace(&self.think_start_token, "").trim().to_string();

        if !processed_text.contains(&self.think_end_token) {
            // Assume reasoning was truncated before `think_end_token`
            return ParserResult {
                normal_text: String::new(),
                reasoning_text: processed_text,
            };
        }

        // Extract reasoning content
        let splits: Vec<&str> = processed_text.splitn(2, &self.think_end_token).collect();
        let reasoning_text = splits.first().unwrap_or(&"").to_string();
        let normal_text = splits
            .get(1)
            .map(|s| s.trim().to_string())
            .unwrap_or_default();

        ParserResult {
            normal_text,
            reasoning_text,
        }
    }

69
70
71
72
73
    fn parse_reasoning_streaming_incremental(
        &mut self,
        text: &str,
        _token_ids: &[u32],
    ) -> ParserResult {
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
        // Incrementally parse the streaming text
        self._buffer.push_str(text);
        let mut current_text = self._buffer.to_string();
        // If the current text is a prefix of the think token, keep buffering

        if self.think_start_token.starts_with(&current_text)
            && self.think_start_token.as_str() != current_text.as_str()
        {
            return ParserResult {
                normal_text: String::new(),
                reasoning_text: String::new(),
            };
        }
        if self.think_end_token.starts_with(&current_text)
            && self.think_end_token.as_str() != current_text.as_str()
        {
            return ParserResult {
                normal_text: String::new(),
                reasoning_text: String::new(),
            };
        }

        // Strip `<think>` token if present
        if !self.stripped_think_start && current_text.contains(&self.think_start_token) {
            current_text = current_text.replace(&self.think_start_token, "");
            self._buffer = current_text.to_string();
            self.stripped_think_start = true;
            self._in_reasoning = true;
        }
        // Handle end of reasoning block
        let mut think_end_idx = current_text.len();
        if self._in_reasoning {
            think_end_idx = current_text
                .find(&self.think_end_token)
                .unwrap_or(current_text.len());
        }
        if self._in_reasoning && think_end_idx < current_text.len() {
            let reasoning_text = &current_text[..think_end_idx];
            self._buffer.clear();
            self._in_reasoning = false;
            let start_idx = think_end_idx + self.think_end_token.len();
            let normal_text = if start_idx < current_text.len() {
                &current_text[start_idx..]
            } else {
                ""
            };
            return ParserResult {
                normal_text: normal_text.to_string(),
                reasoning_text: reasoning_text.trim().to_string(),
            };
        }
        // Continue with reasoning content
        if self._in_reasoning && self.stream_reasoning {
            // Stream the content immediately
            let reasoning_text = current_text;
            self._buffer.clear();
            ParserResult {
                normal_text: String::new(),
                reasoning_text,
            }
        } else if !self._in_reasoning {
            // If we're not in a reasoning block return as normal text
            let normal_text = current_text;
            self._buffer.clear();
            ParserResult {
                normal_text,
                reasoning_text: String::new(),
            }
        } else {
            // If we are in a reasoning block but no end token is found, return the current buffer
            ParserResult {
                normal_text: String::new(),
                reasoning_text: String::new(),
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a parser for the `<think>` / `</think>` token pair used by every test.
    fn think_parser(force_reasoning: bool, stream_reasoning: bool) -> BasicReasoningParser {
        BasicReasoningParser::new(
            "<think>".to_string(),
            "</think>".to_string(),
            force_reasoning,
            stream_reasoning,
        )
    }

    #[test]
    fn test_detect_and_parse_reasoning_reasoning() {
        let mut parser = think_parser(false, true);
        let result =
            parser.detect_and_parse_reasoning("<think>with reasoning</think> and more text.", &[]);
        assert_eq!(result.normal_text, "and more text.");
        assert_eq!(result.reasoning_text, "with reasoning");
    }

    #[test]
    fn test_detect_and_parse_reasoning_reasoning_no_reasoning() {
        let mut parser = think_parser(false, true);
        let result = parser.detect_and_parse_reasoning("This is a test without reasoning.", &[]);
        assert_eq!(result.normal_text, "This is a test without reasoning.");
        assert_eq!(result.reasoning_text, "");
    }

    #[test]
    fn test_detect_and_parse_reasoning_reasoning_truncated_reasoning() {
        let mut parser = think_parser(false, true);
        let result = parser.detect_and_parse_reasoning("<think>with truncated reasoning", &[]);
        assert_eq!(result.normal_text, "");
        assert_eq!(result.reasoning_text, "with truncated reasoning");
    }

    #[test]
    fn test_parse_reasoning_streaming_incremental() {
        let mut parser = think_parser(false, true);
        // A partial start token is buffered, so nothing is emitted yet.
        let result = parser.parse_reasoning_streaming_incremental("<thi", &[]);
        assert_eq!(result.normal_text, "");
        assert_eq!(result.reasoning_text, "");
    }

    #[test]
    fn test_parse_reasoning_streaming_incremental_complete() {
        let mut parser = think_parser(false, true);
        let result = parser.parse_reasoning_streaming_incremental(
            "<think>with reasoning</think> and more text.",
            &[],
        );
        // Streaming keeps the normal text untrimmed.
        assert_eq!(result.normal_text, " and more text.");
        assert_eq!(result.reasoning_text, "with reasoning");
    }

    #[test]
    fn test_parse_reasoning_streaming_incremental_no_end_token() {
        let mut parser = think_parser(true, true);
        let result = parser.parse_reasoning_streaming_incremental("<think>with reasoning", &[]);
        assert_eq!(result.normal_text, "");
        assert_eq!(result.reasoning_text, "with reasoning");
    }

    #[test]
    fn test_detect_and_parse_reasoning_multiple_reasoning_blocks() {
        let mut parser = think_parser(false, true);
        let result = parser.detect_and_parse_reasoning(
            "<think>first reasoning</think> middle <think>second reasoning</think> end",
            &[],
        );
        // Only the first block is split out; the second start tag is stripped with the
        // others and its end tag stays in the normal text.
        assert_eq!(result.normal_text, "middle second reasoning</think> end");
        assert_eq!(result.reasoning_text, "first reasoning");
    }

    #[test]
    fn test_streaming_multiple_reasoning_blocks() {
        let mut parser = think_parser(false, false);

        let result1 = parser
            .parse_reasoning_streaming_incremental("<think>first reasoning</think> middle", &[]);
        assert_eq!(result1.normal_text, " middle");
        assert_eq!(result1.reasoning_text, "first reasoning");

        // The basic parser handles a single reasoning block; later tags pass through.
        let result2 = parser
            .parse_reasoning_streaming_incremental(" <think>second reasoning</think> end", &[]);
        assert_eq!(result2.normal_text, " <think>second reasoning</think> end");
        assert_eq!(result2.reasoning_text, "");
    }

    #[test]
    fn test_partial_token_matching_opening_tag() {
        let mut parser = think_parser(false, true);

        // Feed a partial opening tag.
        let result1 = parser.parse_reasoning_streaming_incremental("<th", &[]);
        assert_eq!(result1.normal_text, "");
        assert_eq!(result1.reasoning_text, "");

        // Complete the opening tag and add content.
        let result2 = parser.parse_reasoning_streaming_incremental(
            "ink>reasoning content</think> normal text",
            &[],
        );
        assert_eq!(result2.normal_text, " normal text");
        assert_eq!(result2.reasoning_text, "reasoning content");
    }

    #[test]
    fn test_partial_token_matching_closing_tag() {
        let mut parser = think_parser(false, false);

        // Complete opening tag followed by content and a partial closing tag.
        let result1 =
            parser.parse_reasoning_streaming_incremental("<think>reasoning content</th", &[]);
        assert_eq!(result1.normal_text, "");
        assert_eq!(result1.reasoning_text, "");

        // Complete the closing tag.
        let result2 = parser.parse_reasoning_streaming_incremental("ink> normal text", &[]);
        assert_eq!(result2.normal_text, " normal text");
        assert_eq!(result2.reasoning_text, "reasoning content");
    }

    #[test]
    fn test_buffer_state_persistence_across_calls() {
        let mut parser = think_parser(false, false);

        // First call: partial opening tag.
        let result1 = parser.parse_reasoning_streaming_incremental("<th", &[]);
        assert_eq!(result1.normal_text, "");
        assert_eq!(result1.reasoning_text, "");

        // Second call: opening tag completed, reasoning starts (buffered, not streamed).
        let result2 = parser.parse_reasoning_streaming_incremental("ink>part1 ", &[]);
        assert_eq!(result2.normal_text, "");
        assert_eq!(result2.reasoning_text, "");

        // Third call: more reasoning content.
        let result3 = parser.parse_reasoning_streaming_incremental("part2 ", &[]);
        assert_eq!(result3.normal_text, "");
        assert_eq!(result3.reasoning_text, "");

        // Fourth call: end of reasoning followed by normal text.
        let result4 = parser.parse_reasoning_streaming_incremental("part3</think> normal", &[]);
        assert_eq!(result4.normal_text, " normal");
        assert_eq!(result4.reasoning_text, "part1 part2 part3");
    }

    #[test]
    fn test_streaming_with_stream_reasoning_enabled() {
        let mut parser = think_parser(false, true);

        // Start the reasoning block.
        let result1 = parser.parse_reasoning_streaming_incremental("<think>reasoning ", &[]);
        assert_eq!(result1.normal_text, "");
        assert_eq!(result1.reasoning_text, "reasoning ");

        // Continue streaming reasoning.
        let result2 = parser.parse_reasoning_streaming_incremental("content ", &[]);
        assert_eq!(result2.normal_text, "");
        assert_eq!(result2.reasoning_text, "content ");

        // End the reasoning block.
        let result3 = parser.parse_reasoning_streaming_incremental("more</think> normal", &[]);
        assert_eq!(result3.normal_text, " normal");
        assert_eq!(result3.reasoning_text, "more");
    }

    #[test]
    fn test_nested_reasoning_blocks() {
        let mut parser = think_parser(false, true);
        let result = parser.detect_and_parse_reasoning(
            "<think>outer <think>inner</think> reasoning</think> normal",
            &[],
        );
        // The first closing tag ends the reasoning block.
        assert_eq!(result.normal_text, "reasoning</think> normal");
        // Every <think> tag is stripped, so the inner tag disappears but "inner" is kept.
        assert_eq!(result.reasoning_text, "outer inner");
    }

    #[test]
    fn test_malformed_missing_closing_tag() {
        let mut parser = think_parser(false, true);
        let result = parser.detect_and_parse_reasoning("<think>reasoning without closing tag", &[]);
        assert_eq!(result.normal_text, "");
        assert_eq!(result.reasoning_text, "reasoning without closing tag");
    }

    #[test]
    fn test_malformed_stray_closing_tag() {
        let mut parser = think_parser(false, true);
        let result = parser.detect_and_parse_reasoning("normal text</think> more normal", &[]);
        assert_eq!(result.normal_text, "normal text</think> more normal");
        assert_eq!(result.reasoning_text, "");
    }

    #[test]
    fn test_malformed_multiple_opening_tags() {
        let mut parser = think_parser(false, true);
        let result = parser
            .detect_and_parse_reasoning("<think>first <think>second reasoning</think> normal", &[]);
        // All opening tags are removed and the first closing tag ends the block.
        assert_eq!(result.normal_text, "normal");
        assert_eq!(result.reasoning_text, "first second reasoning");
    }

    #[test]
    fn test_empty_reasoning_block() {
        let mut parser = think_parser(false, true);
        let result = parser.detect_and_parse_reasoning("<think></think> normal text", &[]);
        assert_eq!(result.normal_text, "normal text");
        assert_eq!(result.reasoning_text, "");
    }

    #[test]
    fn test_whitespace_only_reasoning_block() {
        let mut parser = think_parser(false, true);
        let result = parser.detect_and_parse_reasoning("<think>   \n\t  </think> normal text", &[]);
        assert_eq!(result.normal_text, "normal text");
        // Empty after trimming.
        assert_eq!(result.reasoning_text, "");
    }

    #[test]
    fn test_force_reasoning_mode() {
        let mut parser = think_parser(true, true);
        let result = parser.detect_and_parse_reasoning("no think tags here", &[]);
        assert_eq!(result.normal_text, "");
        assert_eq!(result.reasoning_text, "no think tags here");
    }

    #[test]
    fn test_streaming_reset_state_after_complete_block() {
        let mut parser = think_parser(false, false);

        // Process a complete reasoning block.
        let result1 =
            parser.parse_reasoning_streaming_incremental("<think>reasoning</think> normal", &[]);
        assert_eq!(result1.normal_text, " normal");
        assert_eq!(result1.reasoning_text, "reasoning");

        // Normal text afterwards is not affected by the previous state.
        let result2 = parser.parse_reasoning_streaming_incremental(" more normal text", &[]);
        assert_eq!(result2.normal_text, " more normal text");
        assert_eq!(result2.reasoning_text, "");

        // The basic parser expects a single reasoning block, so a second one is passed
        // through as normal text.
        let result3 = parser
            .parse_reasoning_streaming_incremental(" <think>new reasoning</think> final", &[]);
        assert_eq!(result3.normal_text, " <think>new reasoning</think> final");
        assert_eq!(result3.reasoning_text, "");
    }
}