// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

mod base_parser;
mod gpt_oss_parser;

// Re-export main types and functions for convenience
pub use base_parser::BasicReasoningParser;
pub use gpt_oss_parser::GptOssReasoningParser;

/// Outcome of a single parsing call: the input divided into ordinary text and reasoning text.
#[derive(Debug, Clone, Default)]
pub struct ParserResult {
    /// Text found outside of reasoning blocks.
    pub normal_text: String,

    /// Text extracted from inside reasoning blocks.
    pub reasoning_text: String,
}

impl ParserResult {
    /// Returns an owned copy of the reasoning text, or `None` when it is empty.
    pub fn get_some_reasoning(&self) -> Option<String> {
        let reasoning = &self.reasoning_text;
        (!reasoning.is_empty()).then(|| reasoning.clone())
    }

    /// Returns an owned copy of the normal text, or `None` when it is empty.
    pub fn get_some_normal_text(&self) -> Option<String> {
        let normal = &self.normal_text;
        (!normal.is_empty()).then(|| normal.clone())
    }
}

pub trait ReasoningParser: Send + std::fmt::Debug {
    /// Parses a standalone, non-streaming input chunk. Implementations may reset or ignore
    /// internal streaming state and should return the split of normal vs reasoning text for
    /// this complete input. Marker tokens must not be included in either output.
42
    fn detect_and_parse_reasoning(&mut self, text: &str, token_ids: &[u32]) -> ParserResult;
43
44
45
46

    /// Parses a streaming chunk and updates internal state. The return value should be the
    /// delta: only the newly discovered normal and reasoning text attributable to this chunk
    /// (not the cumulative totals). Marker tokens must not be included in either output.
47
48
49
50
51
    fn parse_reasoning_streaming_incremental(
        &mut self,
        text: &str,
        token_ids: &[u32],
    ) -> ParserResult;
52
53
54
55
56
57
}

/// Identifies which reasoning parser to build.
///
/// The marker strings noted on each variant are the ones passed to the parser constructors in
/// `get_reasoning_parser`; the quoted names are the strings accepted by
/// `get_reasoning_parser_from_name`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum ReasoningParserType {
    /// `<think>` / `</think>` markers, built with the force-reasoning flag set. Name: `deepseek_r1`.
    DeepseekR1,
    /// Same construction as [`ReasoningParserType::DeepseekR1`]. Name: `step3`.
    Step3,
    /// `<think>` / `</think>` markers, force-reasoning flag unset. Name: `basic`; also the
    /// fallback for unknown names.
    Basic,
    /// Dedicated `GptOssReasoningParser`; falls back to the basic parser if it cannot be
    /// initialized. Name: `gpt_oss`.
    GptOss,
    /// Same construction as [`ReasoningParserType::Basic`]. Name: `qwen3`.
    Qwen,
    /// Same construction as [`ReasoningParserType::Basic`]. Name: `nemotron_deci`.
    NemotronDeci,
    /// `◁think▷` / `◁/think▷` markers, force-reasoning flag unset. Name: `kimi`.
    Kimi,
    /// `[THINK]` / `[/THINK]` markers, force-reasoning flag set. Name: `mistral`.
    Mistral,
}

#[derive(std::fmt::Debug)]
pub struct ReasoningParserWrapper {
    parser: Box<dyn ReasoningParser>,
}

impl ReasoningParser for ReasoningParserWrapper {
73
74
    fn detect_and_parse_reasoning(&mut self, text: &str, token_ids: &[u32]) -> ParserResult {
        self.parser.detect_and_parse_reasoning(text, token_ids)
75
76
    }

77
78
79
80
81
82
83
    fn parse_reasoning_streaming_incremental(
        &mut self,
        text: &str,
        token_ids: &[u32],
    ) -> ParserResult {
        self.parser
            .parse_reasoning_streaming_incremental(text, token_ids)
84
85
86
87
88
    }
}

impl ReasoningParserType {
    pub fn get_reasoning_parser(self) -> ReasoningParserWrapper {
89
90
91
92
        let basic_parser =
            BasicReasoningParser::new("<think>".into(), "</think>".into(), false, true);
        let force_reasoning_basic_parser =
            BasicReasoningParser::new("<think>".into(), "</think>".into(), true, true);
93
94
        match self {
            ReasoningParserType::DeepseekR1 => ReasoningParserWrapper {
95
96
97
98
                parser: Box::new(force_reasoning_basic_parser),
            },
            ReasoningParserType::Step3 => ReasoningParserWrapper {
                parser: Box::new(force_reasoning_basic_parser),
99
100
            },
            ReasoningParserType::Basic => ReasoningParserWrapper {
101
102
103
104
105
106
107
108
109
                parser: Box::new(basic_parser),
            },
            ReasoningParserType::Qwen => ReasoningParserWrapper {
                parser: Box::new(basic_parser),
            },
            ReasoningParserType::NemotronDeci => ReasoningParserWrapper {
                parser: Box::new(basic_parser),
            },
            ReasoningParserType::Kimi => ReasoningParserWrapper {
110
                parser: Box::new(BasicReasoningParser::new(
111
112
                    "◁think▷".into(),
                    "◁/think▷".into(),
113
114
115
116
                    false,
                    true,
                )),
            },
117
118
119
120
121
122
123
124
            ReasoningParserType::Mistral => ReasoningParserWrapper {
                parser: Box::new(BasicReasoningParser::new(
                    "[THINK]".into(),
                    "[/THINK]".into(),
                    true,
                    true,
                )),
            },
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
            ReasoningParserType::GptOss => match GptOssReasoningParser::new() {
                Ok(parser) => ReasoningParserWrapper {
                    parser: Box::new(parser),
                },
                Err(e) => {
                    tracing::warn!(
                        "GptOssReasoningParser could not be initialized, falling back to Basic Reasoning Parser: {e}"
                    );
                    ReasoningParserWrapper {
                        parser: Box::new(BasicReasoningParser::new(
                            "<think>".into(),
                            "</think>".into(),
                            false,
                            true,
                        )),
                    }
                }
            },
143
144
        }
    }
145
146

    pub fn get_reasoning_parser_from_name(name: &str) -> ReasoningParserWrapper {
147
        tracing::debug!(parser_name = name, "Selected reasoning parser");
148
149
150
151
        match name.to_lowercase().as_str() {
            "deepseek_r1" => Self::DeepseekR1.get_reasoning_parser(),
            "basic" => Self::Basic.get_reasoning_parser(),
            "gpt_oss" => Self::GptOss.get_reasoning_parser(),
152
153
154
155
156
            "qwen3" => Self::Qwen.get_reasoning_parser(),
            "nemotron_deci" => Self::NemotronDeci.get_reasoning_parser(),
            "kimi" => Self::Kimi.get_reasoning_parser(),
            "step3" => Self::Step3.get_reasoning_parser(),
            "mistral" => Self::Mistral.get_reasoning_parser(),
157
158
            _ => {
                tracing::warn!(
159
160
                    parser_name = name,
                    "Unknown reasoning parser type, falling back to Basic Reasoning Parser",
161
162
163
164
165
                );
                Self::Basic.get_reasoning_parser()
            }
        }
    }
166
}