// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use tracing as log;

use crate::{ParserResult, ReasoningParser};

/// Reasoning parser for models that wrap their reasoning in a single pair of
/// marker tokens (for example `<think>` ... `</think>`).
///
/// The parser keeps streaming state between calls, so one instance should be
/// used per generated response. A `Default` instance has empty marker tokens
/// and all flags set to `false`.
#[derive(Default, Debug, Clone)]
pub struct BasicReasoningParser {
    /// Token that opens a reasoning block.
    think_start_token: String,
    /// Token that closes a reasoning block.
    think_end_token: String,
    /// `true` while the parser considers itself inside a reasoning block.
    _in_reasoning: bool,
    /// When `true`, streamed reasoning is emitted chunk by chunk instead of
    /// being held until the end token arrives.
    stream_reasoning: bool,
    /// Streamed text that has been received but not yet emitted.
    _buffer: String,
    /// `true` once the start token has been seen and removed from the stream.
    stripped_think_start: bool,
}

17
impl BasicReasoningParser {
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
    pub fn new(
        think_start_token: String,
        think_end_token: String,
        force_reasoning: bool,
        stream_reasoning: bool,
    ) -> Self {
        Self {
            think_start_token,
            think_end_token,
            _in_reasoning: force_reasoning,
            stream_reasoning,
            _buffer: String::new(),
            stripped_think_start: false,
        }
    }
}

35
impl ReasoningParser for BasicReasoningParser {
36
    fn detect_and_parse_reasoning(&mut self, text: &str, _token_ids: &[u32]) -> ParserResult {
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
        log::debug!("detect_and_parse_reasoning called with text: {:?}", text);

        let in_reasoning = self._in_reasoning || text.contains(&self.think_start_token);
        log::debug!("in_reasoning: {}", in_reasoning);

        if !in_reasoning {
            log::debug!("No reasoning detected, returning normal text.");
            return ParserResult {
                normal_text: text.to_string(),
                reasoning_text: String::new(),
            };
        }

        // The text is considered to be in a reasoning block.
        let processed_text = text.replace(&self.think_start_token, "").trim().to_string();
        log::debug!(
            "Processed text after removing think_start_token: {:?}",
            processed_text
        );

        if !processed_text.contains(&self.think_end_token) {
            log::debug!(
                "Reasoning truncated, think_end_token not found. Returning reasoning text."
            );
            // Assume reasoning was truncated before `think_end_token`
            return ParserResult {
                normal_text: String::new(),
                reasoning_text: processed_text,
            };
        }

        // Extract reasoning content
        let splits: Vec<&str> = processed_text.splitn(2, &self.think_end_token).collect();
        let reasoning_text = splits.first().unwrap_or(&"").to_string();
        let normal_text = splits
            .get(1)
            .map(|s| s.trim().to_string())
            .unwrap_or_default();

        log::debug!("Extracted reasoning_text: {:?}", reasoning_text);
        log::debug!("Extracted normal_text: {:?}", normal_text);

        ParserResult {
            normal_text,
            reasoning_text,
        }
    }

85
86
87
88
89
    fn parse_reasoning_streaming_incremental(
        &mut self,
        text: &str,
        _token_ids: &[u32],
    ) -> ParserResult {
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
        // Incrementally parse the streaming text
        self._buffer.push_str(text);
        let mut current_text = self._buffer.to_string();
        // If the current text is a prefix of the think token, keep buffering

        log::debug!(
            "parse_reasoning_streaming_incremental called with text: {:?}",
            text
        );
        log::debug!("current buffer: {:?}", self._buffer);
        log::debug!("current_text: {:?}", current_text);
        log::debug!(
            "in_reasoning: {}, stripped_think_start: {}, stream_reasoning: {}",
            self._in_reasoning,
            self.stripped_think_start,
            self.stream_reasoning
        );

        if self.think_start_token.starts_with(&current_text)
            && self.think_start_token.as_str() != current_text.as_str()
        {
            return ParserResult {
                normal_text: String::new(),
                reasoning_text: String::new(),
            };
        }
        if self.think_end_token.starts_with(&current_text)
            && self.think_end_token.as_str() != current_text.as_str()
        {
            return ParserResult {
                normal_text: String::new(),
                reasoning_text: String::new(),
            };
        }

        // Strip `<think>` token if present
        if !self.stripped_think_start && current_text.contains(&self.think_start_token) {
            current_text = current_text.replace(&self.think_start_token, "");
            self._buffer = current_text.to_string();
            self.stripped_think_start = true;
            self._in_reasoning = true;
        }
        // Handle end of reasoning block
        let mut think_end_idx = current_text.len();
        if self._in_reasoning {
            think_end_idx = current_text
                .find(&self.think_end_token)
                .unwrap_or(current_text.len());
        }
        if self._in_reasoning && think_end_idx < current_text.len() {
            let reasoning_text = &current_text[..think_end_idx];
            self._buffer.clear();
            self._in_reasoning = false;
            let start_idx = think_end_idx + self.think_end_token.len();
            let normal_text = if start_idx < current_text.len() {
                &current_text[start_idx..]
            } else {
                ""
            };
            return ParserResult {
                normal_text: normal_text.to_string(),
                reasoning_text: reasoning_text.trim().to_string(),
            };
        }
        // Continue with reasoning content
        if self._in_reasoning && self.stream_reasoning {
            // Stream the content immediately
            let reasoning_text = current_text;
            self._buffer.clear();
            ParserResult {
                normal_text: String::new(),
                reasoning_text,
            }
        } else if !self._in_reasoning {
            // If we're not in a reasoning block return as normal text
            let normal_text = current_text;
            self._buffer.clear();
            ParserResult {
                normal_text,
                reasoning_text: String::new(),
            }
        } else {
            // If we are in a reasoning block but no end token is found, return the current buffer
            ParserResult {
                normal_text: String::new(),
                reasoning_text: String::new(),
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a parser for the `<think>` / `</think>` token pair used by every test.
    fn think_parser(force_reasoning: bool, stream_reasoning: bool) -> BasicReasoningParser {
        BasicReasoningParser::new(
            "<think>".to_string(),
            "</think>".to_string(),
            force_reasoning,
            stream_reasoning,
        )
    }

    #[test]
    fn test_detect_and_parse_reasoning_reasoning() {
        let mut parser = think_parser(false, true);
        let result =
            parser.detect_and_parse_reasoning("<think>with reasoning</think> and more text.", &[]);
        assert_eq!(result.normal_text, "and more text.");
        assert_eq!(result.reasoning_text, "with reasoning");
    }

    #[test]
    fn test_detect_and_parse_reasoning_reasoning_no_reasoning() {
        let mut parser = think_parser(false, true);
        let result = parser.detect_and_parse_reasoning("This is a test without reasoning.", &[]);
        assert_eq!(result.normal_text, "This is a test without reasoning.");
        assert_eq!(result.reasoning_text, "");
    }

    #[test]
    fn test_detect_and_parse_reasoning_reasoning_truncated_reasoning() {
        let mut parser = think_parser(false, true);
        let result = parser.detect_and_parse_reasoning("<think>with truncated reasoning", &[]);
        assert_eq!(result.normal_text, "");
        assert_eq!(result.reasoning_text, "with truncated reasoning");
    }

    #[test]
    fn test_parse_reasoning_streaming_incremental() {
        let mut parser = think_parser(false, true);
        let result = parser.parse_reasoning_streaming_incremental("<thi", &[]);
        assert_eq!(result.normal_text, "");
        assert_eq!(result.reasoning_text, "");
    }

    #[test]
    fn test_parse_reasoning_streaming_incremental_complete() {
        let mut parser = think_parser(false, true);
        let result = parser.parse_reasoning_streaming_incremental(
            "<think>with reasoning</think> and more text.",
            &[],
        );
        assert_eq!(result.normal_text, " and more text.");
        assert_eq!(result.reasoning_text, "with reasoning");
    }

    #[test]
    fn test_parse_reasoning_streaming_incremental_no_end_token() {
        let mut parser = think_parser(true, true);
        let result = parser.parse_reasoning_streaming_incremental("<think>with reasoning", &[]);
        assert_eq!(result.normal_text, "");
        assert_eq!(result.reasoning_text, "with reasoning");
    }

    #[test]
    fn test_detect_and_parse_reasoning_multiple_reasoning_blocks() {
        let mut parser = think_parser(false, true);
        let result = parser.detect_and_parse_reasoning(
            "<think>first reasoning</think> middle <think>second reasoning</think> end",
            &[],
        );
        // The current implementation only handles the first occurrence properly
        assert_eq!(result.normal_text, "middle second reasoning</think> end");
        assert_eq!(result.reasoning_text, "first reasoning");
    }

    #[test]
    fn test_streaming_multiple_reasoning_blocks() {
        let mut parser = think_parser(false, false);
        let result1 = parser
            .parse_reasoning_streaming_incremental("<think>first reasoning</think> middle", &[]);
        assert_eq!(result1.normal_text, " middle");
        assert_eq!(result1.reasoning_text, "first reasoning");

        // Basic parser assumes only one reasoning block at a time
        let result2 = parser
            .parse_reasoning_streaming_incremental(" <think>second reasoning</think> end", &[]);
        assert_eq!(result2.normal_text, " <think>second reasoning</think> end");
        assert_eq!(result2.reasoning_text, "");
    }

    #[test]
    fn test_partial_token_matching_opening_tag() {
        let mut parser = think_parser(false, true);

        // Feed partial opening tag
        let result1 = parser.parse_reasoning_streaming_incremental("<th", &[]);
        assert_eq!(result1.normal_text, "");
        assert_eq!(result1.reasoning_text, "");

        // Complete the opening tag and add content
        let result2 = parser.parse_reasoning_streaming_incremental(
            "ink>reasoning content</think> normal text",
            &[],
        );
        assert_eq!(result2.normal_text, " normal text");
        assert_eq!(result2.reasoning_text, "reasoning content");
    }

    #[test]
    fn test_partial_token_matching_closing_tag() {
        let mut parser = think_parser(false, false);

        // Start with complete opening and partial content
        let result1 =
            parser.parse_reasoning_streaming_incremental("<think>reasoning content</th", &[]);
        assert_eq!(result1.normal_text, "");
        assert_eq!(result1.reasoning_text, "");

        // Complete the closing tag
        let result2 = parser.parse_reasoning_streaming_incremental("ink> normal text", &[]);
        assert_eq!(result2.normal_text, " normal text");
        assert_eq!(result2.reasoning_text, "reasoning content");
    }

    #[test]
    fn test_buffer_state_persistence_across_calls() {
        let mut parser = think_parser(false, false);

        // First call - partial opening tag
        let result1 = parser.parse_reasoning_streaming_incremental("<th", &[]);
        assert_eq!(result1.normal_text, "");
        assert_eq!(result1.reasoning_text, "");

        // Second call - complete opening tag, start reasoning
        let result2 = parser.parse_reasoning_streaming_incremental("ink>part1 ", &[]);
        assert_eq!(result2.normal_text, "");
        assert_eq!(result2.reasoning_text, "");

        // Third call - more reasoning content
        let result3 = parser.parse_reasoning_streaming_incremental("part2 ", &[]);
        assert_eq!(result3.normal_text, "");
        assert_eq!(result3.reasoning_text, "");

        // Fourth call - end reasoning and normal text
        let result4 = parser.parse_reasoning_streaming_incremental("part3</think> normal", &[]);
        assert_eq!(result4.normal_text, " normal");
        assert_eq!(result4.reasoning_text, "part1 part2 part3");
    }

    #[test]
    fn test_streaming_with_stream_reasoning_enabled() {
        let mut parser = think_parser(false, true);

        // Start reasoning block
        let result1 = parser.parse_reasoning_streaming_incremental("<think>reasoning ", &[]);
        assert_eq!(result1.normal_text, "");
        assert_eq!(result1.reasoning_text, "reasoning ");

        // Continue streaming reasoning
        let result2 = parser.parse_reasoning_streaming_incremental("content ", &[]);
        assert_eq!(result2.normal_text, "");
        assert_eq!(result2.reasoning_text, "content ");

        // End reasoning block
        let result3 = parser.parse_reasoning_streaming_incremental("more</think> normal", &[]);
        assert_eq!(result3.normal_text, " normal");
        assert_eq!(result3.reasoning_text, "more");
    }

    #[test]
    fn test_nested_reasoning_blocks() {
        let mut parser = think_parser(false, true);
        let result = parser.detect_and_parse_reasoning(
            "<think>outer <think>inner</think> reasoning</think> normal",
            &[],
        );
        // Current implementation should handle this by finding the first closing tag
        assert_eq!(result.normal_text, "reasoning</think> normal");
        // All <think> tags are stripped, so <think>inner is not included
        assert_eq!(result.reasoning_text, "outer inner");
    }

    #[test]
    fn test_malformed_missing_closing_tag() {
        let mut parser = think_parser(false, true);
        let result = parser.detect_and_parse_reasoning("<think>reasoning without closing tag", &[]);
        assert_eq!(result.normal_text, "");
        assert_eq!(result.reasoning_text, "reasoning without closing tag");
    }

    #[test]
    fn test_malformed_stray_closing_tag() {
        let mut parser = think_parser(false, true);
        let result = parser.detect_and_parse_reasoning("normal text</think> more normal", &[]);
        assert_eq!(result.normal_text, "normal text</think> more normal");
        assert_eq!(result.reasoning_text, "");
    }

    #[test]
    fn test_malformed_multiple_opening_tags() {
        let mut parser = think_parser(false, true);
        let result = parser
            .detect_and_parse_reasoning("<think>first <think>second reasoning</think> normal", &[]);
        // Should handle by replacing all opening tags and using first closing tag
        assert_eq!(result.normal_text, "normal");
        assert_eq!(result.reasoning_text, "first second reasoning");
    }

    #[test]
    fn test_empty_reasoning_block() {
        let mut parser = think_parser(false, true);
        let result = parser.detect_and_parse_reasoning("<think></think> normal text", &[]);
        assert_eq!(result.normal_text, "normal text");
        assert_eq!(result.reasoning_text, "");
    }

    #[test]
    fn test_whitespace_only_reasoning_block() {
        let mut parser = think_parser(false, true);
        let result = parser.detect_and_parse_reasoning("<think>   \n\t  </think> normal text", &[]);
        assert_eq!(result.normal_text, "normal text");
        assert_eq!(result.reasoning_text, ""); // Should be empty after trim
    }

    #[test]
    fn test_force_reasoning_mode() {
        let mut parser = think_parser(true, true);
        let result = parser.detect_and_parse_reasoning("no think tags here", &[]);
        assert_eq!(result.normal_text, "");
        assert_eq!(result.reasoning_text, "no think tags here");
    }

    #[test]
    fn test_streaming_reset_state_after_complete_block() {
        let mut parser = think_parser(false, false);

        // Process complete reasoning block
        let result1 =
            parser.parse_reasoning_streaming_incremental("<think>reasoning</think> normal", &[]);
        assert_eq!(result1.normal_text, " normal");
        assert_eq!(result1.reasoning_text, "reasoning");

        // Process normal text - should not be affected by previous state
        let result2 = parser.parse_reasoning_streaming_incremental(" more normal text", &[]);
        assert_eq!(result2.normal_text, " more normal text");
        assert_eq!(result2.reasoning_text, "");

        // Basic parser does not expect more than one reasoning block at a time
        // So this should not affect the state
        let result3 = parser
            .parse_reasoning_streaming_incremental(" <think>new reasoning</think> final", &[]);
        assert_eq!(result3.normal_text, " <think>new reasoning</think> final");
        assert_eq!(result3.reasoning_text, "");
    }
}