// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::HashMap;
use std::sync::OnceLock;
5

6
mod base_parser;
7
mod gpt_oss_parser;
8
mod granite_parser;
9
mod minimax_append_think_parser;
10
11

// Re-export main types and functions for convenience
12
pub use base_parser::BasicReasoningParser;
13
pub use gpt_oss_parser::GptOssReasoningParser;
14
pub use granite_parser::GraniteReasoningParser;
15
pub use minimax_append_think_parser::MiniMaxAppendThinkParser;
16

17
18
19
20
21
22
23
24
25
26
27
28
29
30
/// Process-wide table from parser name to [`ReasoningParserType`], built lazily on first access.
static REASONING_PARSER_MAP: OnceLock<HashMap<&'static str, ReasoningParserType>> = OnceLock::new();

/// Returns the global reasoning parser map, initializing it on the first call.
///
/// Keys are lowercase; `ReasoningParserType::get_reasoning_parser_from_name` lowercases the
/// requested name before looking it up here. Several names intentionally share one parser type.
fn get_reasoning_parser_map() -> &'static HashMap<&'static str, ReasoningParserType> {
    REASONING_PARSER_MAP.get_or_init(|| {
        HashMap::from([
            ("deepseek_r1", ReasoningParserType::DeepseekR1),
            ("basic", ReasoningParserType::Basic),
            ("gpt_oss", ReasoningParserType::GptOss),
            ("qwen3", ReasoningParserType::Qwen),
            ("nemotron_deci", ReasoningParserType::NemotronDeci),
            ("kimi", ReasoningParserType::Kimi),
            ("step3", ReasoningParserType::Step3),
            ("mistral", ReasoningParserType::Mistral),
            ("granite", ReasoningParserType::Granite),
            // nemotron nano is <think>...</think>
            ("nemotron_nano", ReasoningParserType::NemotronDeci),
            // GLM-4.5/5 is <think>...</think>, no force_reasoning
            ("glm45", ReasoningParserType::NemotronDeci),
            (
                "minimax_append_think",
                ReasoningParserType::MiniMaxAppendThink,
            ),
        ])
    })
}

/// Get all available reasoning parser names, sorted alphabetically.
///
/// The names are the keys of the global parser map. `HashMap` iteration order is unspecified
/// and can differ between runs, so the list is sorted to give callers a stable order.
pub fn get_available_reasoning_parsers() -> Vec<&'static str> {
    let mut names: Vec<&'static str> = get_reasoning_parser_map().keys().copied().collect();
    // Keys are unique, so an unstable sort yields the same result as a stable one.
    names.sort_unstable();
    names
}

47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
/// Result of splitting model output into regular text and reasoning text.
///
/// Both fields default to empty strings.
#[derive(Debug, Clone, Default)]
pub struct ParserResult {
    /// Text found outside of reasoning blocks.
    pub normal_text: String,

    /// Text extracted from inside reasoning blocks.
    pub reasoning_text: String,
}

impl ParserResult {
    /// Returns an owned copy of the reasoning text, or `None` when it is empty.
    pub fn get_some_reasoning(&self) -> Option<String> {
        let has_reasoning = !self.reasoning_text.is_empty();
        has_reasoning.then(|| self.reasoning_text.clone())
    }

    /// Returns an owned copy of the normal text, or `None` when it is empty.
    pub fn get_some_normal_text(&self) -> Option<String> {
        let has_text = !self.normal_text.is_empty();
        has_text.then(|| self.normal_text.clone())
    }
}

pub trait ReasoningParser: Send + std::fmt::Debug {
    /// Parses a standalone, non-streaming input chunk. Implementations may reset or ignore
    /// internal streaming state and should return the split of normal vs reasoning text for
    /// this complete input. Marker tokens must not be included in either output.
78
    fn detect_and_parse_reasoning(&mut self, text: &str, token_ids: &[u32]) -> ParserResult;
79
80
81
82

    /// Parses a streaming chunk and updates internal state. The return value should be the
    /// delta: only the newly discovered normal and reasoning text attributable to this chunk
    /// (not the cumulative totals). Marker tokens must not be included in either output.
83
84
85
86
87
    fn parse_reasoning_streaming_incremental(
        &mut self,
        text: &str,
        token_ids: &[u32],
    ) -> ParserResult;
88
89
90
91
92
93
}

/// Identifies which reasoning parser to construct.
///
/// The per-variant notes describe what `get_reasoning_parser` builds for each variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum ReasoningParserType {
    /// `<think>`/`</think>` basic parser with the force-reasoning flag set.
    DeepseekR1,
    /// `<think>`/`</think>` basic parser with the force-reasoning flag set.
    Step3,
    /// `<think>`/`</think>` basic parser without force-reasoning; also the fallback for
    /// unknown parser names.
    Basic,
    /// `GptOssReasoningParser`; falls back to the basic parser if it cannot be initialized.
    GptOss,
    /// `<think>`/`</think>` basic parser without force-reasoning (registered as `"qwen3"`).
    Qwen,
    /// `<think>`/`</think>` basic parser without force-reasoning; also registered under
    /// `"nemotron_nano"` and `"glm45"`.
    NemotronDeci,
    /// Basic parser using the `◁think▷`/`◁/think▷` markers, without force-reasoning.
    Kimi,
    /// Basic parser using the `[THINK]`/`[/THINK]` markers, with the force-reasoning flag set.
    Mistral,
    /// `GraniteReasoningParser`.
    Granite,
    /// `MiniMaxAppendThinkParser`.
    MiniMaxAppendThink,
}

#[derive(std::fmt::Debug)]
pub struct ReasoningParserWrapper {
    parser: Box<dyn ReasoningParser>,
}

impl ReasoningParser for ReasoningParserWrapper {
111
112
    fn detect_and_parse_reasoning(&mut self, text: &str, token_ids: &[u32]) -> ParserResult {
        self.parser.detect_and_parse_reasoning(text, token_ids)
113
114
    }

115
116
117
118
119
120
121
    fn parse_reasoning_streaming_incremental(
        &mut self,
        text: &str,
        token_ids: &[u32],
    ) -> ParserResult {
        self.parser
            .parse_reasoning_streaming_incremental(text, token_ids)
122
123
124
125
126
    }
}

impl ReasoningParserType {
    pub fn get_reasoning_parser(self) -> ReasoningParserWrapper {
127
128
129
130
        let basic_parser =
            BasicReasoningParser::new("<think>".into(), "</think>".into(), false, true);
        let force_reasoning_basic_parser =
            BasicReasoningParser::new("<think>".into(), "</think>".into(), true, true);
131
132
        match self {
            ReasoningParserType::DeepseekR1 => ReasoningParserWrapper {
133
134
135
136
                parser: Box::new(force_reasoning_basic_parser),
            },
            ReasoningParserType::Step3 => ReasoningParserWrapper {
                parser: Box::new(force_reasoning_basic_parser),
137
138
            },
            ReasoningParserType::Basic => ReasoningParserWrapper {
139
140
141
142
143
144
145
146
147
                parser: Box::new(basic_parser),
            },
            ReasoningParserType::Qwen => ReasoningParserWrapper {
                parser: Box::new(basic_parser),
            },
            ReasoningParserType::NemotronDeci => ReasoningParserWrapper {
                parser: Box::new(basic_parser),
            },
            ReasoningParserType::Kimi => ReasoningParserWrapper {
148
                parser: Box::new(BasicReasoningParser::new(
149
150
                    "◁think▷".into(),
                    "◁/think▷".into(),
151
152
153
154
                    false,
                    true,
                )),
            },
155
156
157
158
159
160
161
162
            ReasoningParserType::Mistral => ReasoningParserWrapper {
                parser: Box::new(BasicReasoningParser::new(
                    "[THINK]".into(),
                    "[/THINK]".into(),
                    true,
                    true,
                )),
            },
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
            ReasoningParserType::GptOss => match GptOssReasoningParser::new() {
                Ok(parser) => ReasoningParserWrapper {
                    parser: Box::new(parser),
                },
                Err(e) => {
                    tracing::warn!(
                        "GptOssReasoningParser could not be initialized, falling back to Basic Reasoning Parser: {e}"
                    );
                    ReasoningParserWrapper {
                        parser: Box::new(BasicReasoningParser::new(
                            "<think>".into(),
                            "</think>".into(),
                            false,
                            true,
                        )),
                    }
                }
            },
181
182
183
            ReasoningParserType::Granite => ReasoningParserWrapper {
                parser: Box::new(GraniteReasoningParser::new()),
            },
184
185
186
            ReasoningParserType::MiniMaxAppendThink => ReasoningParserWrapper {
                parser: Box::new(MiniMaxAppendThinkParser::new()),
            },
187
188
        }
    }
189
190

    pub fn get_reasoning_parser_from_name(name: &str) -> ReasoningParserWrapper {
191
192
193
194
195
196
197
198
        tracing::debug!("Selected reasoning parser: {}", name);

        let parser_map = get_reasoning_parser_map();
        let normalized_name = name.to_lowercase();

        match parser_map.get(normalized_name.as_str()) {
            Some(parser_type) => parser_type.get_reasoning_parser(),
            None => {
199
                tracing::warn!(
200
201
                    parser_name = name,
                    "Unknown reasoning parser type, falling back to Basic Reasoning Parser",
202
203
204
205
206
                );
                Self::Basic.get_reasoning_parser()
            }
        }
    }
207
}
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226

#[cfg(test)]
mod tests {
    use super::*;

    /// Every parser name listed below must be reported by `get_available_reasoning_parsers`.
    #[test]
    fn test_get_available_reasoning_parsers() {
        let parsers = get_available_reasoning_parsers();
        assert!(!parsers.is_empty());
        // Update this list when adding a new parser
        let available_parsers = [
            "deepseek_r1",
            "basic",
            "gpt_oss",
            "qwen3",
            "nemotron_deci",
            "kimi",
            "step3",
            "mistral",
            "granite",
            "nemotron_nano",
            "glm45",
            "minimax_append_think",
        ];
        for parser in available_parsers {
            assert!(
                parsers.contains(&parser),
                "reasoning parser `{parser}` is not registered"
            );
        }
    }
}