// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
5
mod base_parser;
mod deepseek_r1_parser;
6
mod gpt_oss_parser;
7
8

// Re-export main types and functions for convenience
9
pub use base_parser::BasicReasoningParser;
10
pub use deepseek_r1_parser::DeepseekR1ReasoningParser;
11
pub use gpt_oss_parser::GptOssReasoningParser;
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43

/// Output of a reasoning parser: the input split into ordinary and reasoning text.
#[derive(Debug, Clone, Default)]
pub struct ParserResult {
    /// Text found outside of any reasoning block.
    pub normal_text: String,

    /// Text extracted from inside reasoning blocks.
    pub reasoning_text: String,
}

impl ParserResult {
    /// Returns an owned copy of `reasoning_text`, or `None` when it is empty.
    pub fn get_some_reasoning(&self) -> Option<String> {
        Self::owned_unless_empty(&self.reasoning_text)
    }

    /// Returns an owned copy of `normal_text`, or `None` when it is empty.
    pub fn get_some_normal_text(&self) -> Option<String> {
        Self::owned_unless_empty(&self.normal_text)
    }

    /// Clones `text` into `Some`, mapping the empty string to `None`.
    ///
    /// Only the zero-length string counts as empty; whitespace is kept as-is.
    fn owned_unless_empty(text: &str) -> Option<String> {
        (!text.is_empty()).then(|| text.to_owned())
    }
}

pub trait ReasoningParser: Send + std::fmt::Debug {
    /// Parses a standalone, non-streaming input chunk. Implementations may reset or ignore
    /// internal streaming state and should return the split of normal vs reasoning text for
    /// this complete input. Marker tokens must not be included in either output.
44
    fn detect_and_parse_reasoning(&mut self, text: &str, token_ids: &[u32]) -> ParserResult;
45
46
47
48

    /// Parses a streaming chunk and updates internal state. The return value should be the
    /// delta: only the newly discovered normal and reasoning text attributable to this chunk
    /// (not the cumulative totals). Marker tokens must not be included in either output.
49
50
51
52
53
    fn parse_reasoning_streaming_incremental(
        &mut self,
        text: &str,
        token_ids: &[u32],
    ) -> ParserResult;
54
55
56
57
58
59
60
}

/// Identifies which reasoning parser implementation to construct.
///
/// Marked `#[non_exhaustive]`, so code outside this crate must include a wildcard
/// arm when matching on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum ReasoningParserType {
    /// Selects the DeepSeek-R1 reasoning parser.
    DeepseekR1,
    /// Selects the basic reasoning parser.
    Basic,
    /// Selects the GPT-OSS reasoning parser.
    GptOss,
}

#[derive(std::fmt::Debug)]
pub struct ReasoningParserWrapper {
    parser: Box<dyn ReasoningParser>,
}

impl ReasoningParser for ReasoningParserWrapper {
70
71
    fn detect_and_parse_reasoning(&mut self, text: &str, token_ids: &[u32]) -> ParserResult {
        self.parser.detect_and_parse_reasoning(text, token_ids)
72
73
    }

74
75
76
77
78
79
80
    fn parse_reasoning_streaming_incremental(
        &mut self,
        text: &str,
        token_ids: &[u32],
    ) -> ParserResult {
        self.parser
            .parse_reasoning_streaming_incremental(text, token_ids)
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
    }
}

impl ReasoningParserType {
    pub fn get_reasoning_parser(self) -> ReasoningParserWrapper {
        match self {
            ReasoningParserType::DeepseekR1 => ReasoningParserWrapper {
                parser: Box::new(DeepseekR1ReasoningParser::new()),
            },
            ReasoningParserType::Basic => ReasoningParserWrapper {
                parser: Box::new(BasicReasoningParser::new(
                    "<think>".into(),
                    "</think>".into(),
                    false,
                    true,
                )),
            },
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
            ReasoningParserType::GptOss => match GptOssReasoningParser::new() {
                Ok(parser) => ReasoningParserWrapper {
                    parser: Box::new(parser),
                },
                Err(e) => {
                    tracing::warn!(
                        "GptOssReasoningParser could not be initialized, falling back to Basic Reasoning Parser: {e}"
                    );
                    ReasoningParserWrapper {
                        parser: Box::new(BasicReasoningParser::new(
                            "<think>".into(),
                            "</think>".into(),
                            false,
                            true,
                        )),
                    }
                }
            },
116
117
        }
    }
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133

    pub fn get_reasoning_parser_from_name(name: &str) -> ReasoningParserWrapper {
        tracing::debug!("Selected reasoning parser: {}", name);
        match name.to_lowercase().as_str() {
            "deepseek_r1" => Self::DeepseekR1.get_reasoning_parser(),
            "basic" => Self::Basic.get_reasoning_parser(),
            "gpt_oss" => Self::GptOss.get_reasoning_parser(),
            _ => {
                tracing::warn!(
                    "Unknown reasoning parser type '{}', falling back to Basic Reasoning Parser",
                    name
                );
                Self::Basic.get_reasoning_parser()
            }
        }
    }
134
}