// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use tracing as log;

use crate::{ParserResult, ReasoningParser};

/// Reasoning parser driven by a pair of start/end marker tokens.
///
/// Text found between the start token and the end token is reported as
/// reasoning; everything else is reported as normal output. The same value
/// also carries the state needed for incremental (streaming) parsing.
#[derive(Default, Debug, Clone)]
pub struct BasicReasoningParser {
    /// Token that opens a reasoning block (the tests in this file use `<think>`).
    think_start_token: String,
    /// Token that closes a reasoning block (the tests in this file use `</think>`).
    think_end_token: String,
    /// Whether incoming text is currently treated as reasoning content.
    _in_reasoning: bool,
    /// When `true`, streamed reasoning is emitted as it arrives; when `false`
    /// it stays buffered until the end token has been seen.
    stream_reasoning: bool,
    /// Streamed text that has been received but not yet emitted.
    _buffer: String,
    /// Whether the start token has already been removed from the stream.
    stripped_think_start: bool,
}

impl BasicReasoningParser {
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
    pub fn new(
        think_start_token: String,
        think_end_token: String,
        force_reasoning: bool,
        stream_reasoning: bool,
    ) -> Self {
        Self {
            think_start_token,
            think_end_token,
            _in_reasoning: force_reasoning,
            stream_reasoning,
            _buffer: String::new(),
            stripped_think_start: false,
        }
    }
}

35
36
impl ReasoningParser for BasicReasoningParser {
    fn detect_and_parse_reasoning(&self, text: &str) -> ParserResult {
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
        log::debug!("detect_and_parse_reasoning called with text: {:?}", text);

        let in_reasoning = self._in_reasoning || text.contains(&self.think_start_token);
        log::debug!("in_reasoning: {}", in_reasoning);

        if !in_reasoning {
            log::debug!("No reasoning detected, returning normal text.");
            return ParserResult {
                normal_text: text.to_string(),
                reasoning_text: String::new(),
            };
        }

        // The text is considered to be in a reasoning block.
        let processed_text = text.replace(&self.think_start_token, "").trim().to_string();
        log::debug!(
            "Processed text after removing think_start_token: {:?}",
            processed_text
        );

        if !processed_text.contains(&self.think_end_token) {
            log::debug!(
                "Reasoning truncated, think_end_token not found. Returning reasoning text."
            );
            // Assume reasoning was truncated before `think_end_token`
            return ParserResult {
                normal_text: String::new(),
                reasoning_text: processed_text,
            };
        }

        // Extract reasoning content
        let splits: Vec<&str> = processed_text.splitn(2, &self.think_end_token).collect();
        let reasoning_text = splits.first().unwrap_or(&"").to_string();
        let normal_text = splits
            .get(1)
            .map(|s| s.trim().to_string())
            .unwrap_or_default();

        log::debug!("Extracted reasoning_text: {:?}", reasoning_text);
        log::debug!("Extracted normal_text: {:?}", normal_text);

        ParserResult {
            normal_text,
            reasoning_text,
        }
    }

    fn parse_reasoning_streaming_incremental(&mut self, text: &str) -> ParserResult {
        // Incrementally parse the streaming text
        self._buffer.push_str(text);
        let mut current_text = self._buffer.to_string();
        // If the current text is a prefix of the think token, keep buffering

        log::debug!(
            "parse_reasoning_streaming_incremental called with text: {:?}",
            text
        );
        log::debug!("current buffer: {:?}", self._buffer);
        log::debug!("current_text: {:?}", current_text);
        log::debug!(
            "in_reasoning: {}, stripped_think_start: {}, stream_reasoning: {}",
            self._in_reasoning,
            self.stripped_think_start,
            self.stream_reasoning
        );

        if self.think_start_token.starts_with(&current_text)
            && self.think_start_token.as_str() != current_text.as_str()
        {
            return ParserResult {
                normal_text: String::new(),
                reasoning_text: String::new(),
            };
        }
        if self.think_end_token.starts_with(&current_text)
            && self.think_end_token.as_str() != current_text.as_str()
        {
            return ParserResult {
                normal_text: String::new(),
                reasoning_text: String::new(),
            };
        }

        // Strip `<think>` token if present
        if !self.stripped_think_start && current_text.contains(&self.think_start_token) {
            current_text = current_text.replace(&self.think_start_token, "");
            self._buffer = current_text.to_string();
            self.stripped_think_start = true;
            self._in_reasoning = true;
        }
        // Handle end of reasoning block
        let mut think_end_idx = current_text.len();
        if self._in_reasoning {
            think_end_idx = current_text
                .find(&self.think_end_token)
                .unwrap_or(current_text.len());
        }
        if self._in_reasoning && think_end_idx < current_text.len() {
            let reasoning_text = &current_text[..think_end_idx];
            self._buffer.clear();
            self._in_reasoning = false;
            let start_idx = think_end_idx + self.think_end_token.len();
            let normal_text = if start_idx < current_text.len() {
                &current_text[start_idx..]
            } else {
                ""
            };
            return ParserResult {
                normal_text: normal_text.to_string(),
                reasoning_text: reasoning_text.trim().to_string(),
            };
        }
        // Continue with reasoning content
        if self._in_reasoning && self.stream_reasoning {
            // Stream the content immediately
            let reasoning_text = current_text;
            self._buffer.clear();
            ParserResult {
                normal_text: String::new(),
                reasoning_text,
            }
        } else if !self._in_reasoning {
            // If we're not in a reasoning block return as normal text
            let normal_text = current_text;
            self._buffer.clear();
            ParserResult {
                normal_text,
                reasoning_text: String::new(),
            }
        } else {
            // If we are in a reasoning block but no end token is found, return the current buffer
            ParserResult {
                normal_text: String::new(),
                reasoning_text: String::new(),
            }
        }
    }
}

// Unit tests for `BasicReasoningParser`, covering both one-shot parsing
// (`detect_and_parse_reasoning`) and incremental parsing
// (`parse_reasoning_streaming_incremental`) with `<think>` / `</think>` tokens.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_detect_and_parse_reasoning_reasoning() {
        let parser =
            BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
        let result =
            parser.detect_and_parse_reasoning("<think>with reasoning</think> and more text.");
        assert_eq!(result.normal_text, "and more text.");
        assert_eq!(result.reasoning_text, "with reasoning");
    }
    #[test]
    fn test_detect_and_parse_reasoning_reasoning_no_reasoning() {
        let parser =
            BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
        let result = parser.detect_and_parse_reasoning("This is a test without reasoning.");
        assert_eq!(result.normal_text, "This is a test without reasoning.");
        assert_eq!(result.reasoning_text, "");
    }
    #[test]
    fn test_detect_and_parse_reasoning_reasoning_truncated_reasoning() {
        let parser =
            BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
        let result = parser.detect_and_parse_reasoning("<think>with truncated reasoning");
        assert_eq!(result.normal_text, "");
        assert_eq!(result.reasoning_text, "with truncated reasoning");
    }

    #[test]
    fn test_parse_reasoning_streaming_incremental() {
        let mut parser =
            BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
        let result = parser.parse_reasoning_streaming_incremental("<thi");
        assert_eq!(result.normal_text, "");
        assert_eq!(result.reasoning_text, "");
    }

    #[test]
    fn test_parse_reasoning_streaming_incremental_complete() {
        let mut parser =
            BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
        let result = parser
            .parse_reasoning_streaming_incremental("<think>with reasoning</think> and more text.");
        assert_eq!(result.normal_text, " and more text.");
        assert_eq!(result.reasoning_text, "with reasoning");
    }

    #[test]
    fn test_parse_reasoning_streaming_incremental_no_end_token() {
        let mut parser =
            BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), true, true);
        let result = parser.parse_reasoning_streaming_incremental("<think>with reasoning");
        assert_eq!(result.normal_text, "");
        assert_eq!(result.reasoning_text, "with reasoning");
    }

    #[test]
    fn test_detect_and_parse_reasoning_multiple_reasoning_blocks() {
        let parser =
            BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
        let result = parser.detect_and_parse_reasoning(
            "<think>first reasoning</think> middle <think>second reasoning</think> end",
        );
        // The current implementation only handles the first occurrence properly
        assert_eq!(result.normal_text, "middle second reasoning</think> end");
        assert_eq!(result.reasoning_text, "first reasoning");
    }

    #[test]
    fn test_streaming_multiple_reasoning_blocks() {
        let mut parser =
            BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, false);
        let result1 =
            parser.parse_reasoning_streaming_incremental("<think>first reasoning</think> middle");
        assert_eq!(result1.normal_text, " middle");
        assert_eq!(result1.reasoning_text, "first reasoning");

        // Basic parser assumes only one reasoning block at a time
        let result2 =
            parser.parse_reasoning_streaming_incremental(" <think>second reasoning</think> end");
        assert_eq!(result2.normal_text, " <think>second reasoning</think> end");
        assert_eq!(result2.reasoning_text, "");
    }

    #[test]
    fn test_partial_token_matching_opening_tag() {
        let mut parser =
            BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);

        // Feed partial opening tag
        let result1 = parser.parse_reasoning_streaming_incremental("<th");
        assert_eq!(result1.normal_text, "");
        assert_eq!(result1.reasoning_text, "");

        // Complete the opening tag and add content
        let result2 = parser
            .parse_reasoning_streaming_incremental("ink>reasoning content</think> normal text");
        assert_eq!(result2.normal_text, " normal text");
        assert_eq!(result2.reasoning_text, "reasoning content");
    }

    #[test]
    fn test_partial_token_matching_closing_tag() {
        let mut parser =
            BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, false);

        // Start with complete opening and partial content
        let result1 = parser.parse_reasoning_streaming_incremental("<think>reasoning content</th");
        assert_eq!(result1.normal_text, "");
        assert_eq!(result1.reasoning_text, "");

        // Complete the closing tag
        let result2 = parser.parse_reasoning_streaming_incremental("ink> normal text");
        assert_eq!(result2.normal_text, " normal text");
        assert_eq!(result2.reasoning_text, "reasoning content");
    }

    #[test]
    fn test_buffer_state_persistence_across_calls() {
        let mut parser =
            BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, false);

        // First call - partial opening tag
        let result1 = parser.parse_reasoning_streaming_incremental("<th");
        assert_eq!(result1.normal_text, "");
        assert_eq!(result1.reasoning_text, "");

        // Second call - complete opening tag, start reasoning
        let result2 = parser.parse_reasoning_streaming_incremental("ink>part1 ");
        assert_eq!(result2.normal_text, "");
        assert_eq!(result2.reasoning_text, "");

        // Third call - more reasoning content
        let result3 = parser.parse_reasoning_streaming_incremental("part2 ");
        assert_eq!(result3.normal_text, "");
        assert_eq!(result3.reasoning_text, "");

        // Fourth call - end reasoning and normal text
        let result4 = parser.parse_reasoning_streaming_incremental("part3</think> normal");
        assert_eq!(result4.normal_text, " normal");
        assert_eq!(result4.reasoning_text, "part1 part2 part3");
    }

    #[test]
    fn test_streaming_with_stream_reasoning_enabled() {
        let mut parser =
            BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);

        // Start reasoning block
        let result1 = parser.parse_reasoning_streaming_incremental("<think>reasoning ");
        assert_eq!(result1.normal_text, "");
        assert_eq!(result1.reasoning_text, "reasoning ");

        // Continue streaming reasoning
        let result2 = parser.parse_reasoning_streaming_incremental("content ");
        assert_eq!(result2.normal_text, "");
        assert_eq!(result2.reasoning_text, "content ");

        // End reasoning block
        let result3 = parser.parse_reasoning_streaming_incremental("more</think> normal");
        assert_eq!(result3.normal_text, " normal");
        assert_eq!(result3.reasoning_text, "more");
    }

    #[test]
    fn test_nested_reasoning_blocks() {
        let parser =
            BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
        let result = parser.detect_and_parse_reasoning(
            "<think>outer <think>inner</think> reasoning</think> normal",
        );
        // Current implementation should handle this by finding the first closing tag
        assert_eq!(result.normal_text, "reasoning</think> normal");
        // All <think> tags are stripped, so <think>inner is not included
        assert_eq!(result.reasoning_text, "outer inner");
    }

    #[test]
    fn test_malformed_missing_closing_tag() {
        let parser =
            BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
        let result = parser.detect_and_parse_reasoning("<think>reasoning without closing tag");
        assert_eq!(result.normal_text, "");
        assert_eq!(result.reasoning_text, "reasoning without closing tag");
    }

    #[test]
    fn test_malformed_stray_closing_tag() {
        let parser =
            BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
        let result = parser.detect_and_parse_reasoning("normal text</think> more normal");
        assert_eq!(result.normal_text, "normal text</think> more normal");
        assert_eq!(result.reasoning_text, "");
    }

    #[test]
    fn test_malformed_multiple_opening_tags() {
        let parser =
            BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
        let result = parser
            .detect_and_parse_reasoning("<think>first <think>second reasoning</think> normal");
        // Should handle by replacing all opening tags and using first closing tag
        assert_eq!(result.normal_text, "normal");
        assert_eq!(result.reasoning_text, "first second reasoning");
    }

    #[test]
    fn test_empty_reasoning_block() {
        let parser =
            BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
        let result = parser.detect_and_parse_reasoning("<think></think> normal text");
        assert_eq!(result.normal_text, "normal text");
        assert_eq!(result.reasoning_text, "");
    }

    #[test]
    fn test_whitespace_only_reasoning_block() {
        let parser =
            BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
        let result = parser.detect_and_parse_reasoning("<think>   \n\t  </think> normal text");
        assert_eq!(result.normal_text, "normal text");
        assert_eq!(result.reasoning_text, ""); // Should be empty after trim
    }

    #[test]
    fn test_force_reasoning_mode() {
        let parser =
            BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), true, true);
        let result = parser.detect_and_parse_reasoning("no think tags here");
        assert_eq!(result.normal_text, "");
        assert_eq!(result.reasoning_text, "no think tags here");
    }

    #[test]
    fn test_streaming_reset_state_after_complete_block() {
        let mut parser =
            BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, false);

        // Process complete reasoning block
        let result1 =
            parser.parse_reasoning_streaming_incremental("<think>reasoning</think> normal");
        assert_eq!(result1.normal_text, " normal");
        assert_eq!(result1.reasoning_text, "reasoning");

        // Process normal text - should not be affected by previous state
        let result2 = parser.parse_reasoning_streaming_incremental(" more normal text");
        assert_eq!(result2.normal_text, " more normal text");
        assert_eq!(result2.reasoning_text, "");

        // Basic parser does not expect more than one reasoning block at a time
        // So this should not affect the state
        let result3 =
            parser.parse_reasoning_streaming_incremental(" <think>new reasoning</think> final");
        assert_eq!(result3.normal_text, " <think>new reasoning</think> final");
        assert_eq!(result3.reasoning_text, "");
    }
}