// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::collections::HashMap;
use std::sync::OnceLock;
5

6
mod base_parser;
7
mod gpt_oss_parser;
8
mod granite_parser;
9
10

// Re-export main types and functions for convenience
11
pub use base_parser::BasicReasoningParser;
12
pub use gpt_oss_parser::GptOssReasoningParser;
13
pub use granite_parser::GraniteReasoningParser;
14

15
16
17
18
19
20
21
22
23
24
25
26
27
28
/// Process-wide registry of reasoning parser names, populated on first access.
static REASONING_PARSER_MAP: OnceLock<HashMap<&'static str, ReasoningParserType>> = OnceLock::new();

/// Returns the global name → [`ReasoningParserType`] map, building it on the first call.
///
/// Every key is lowercase; `ReasoningParserType::get_reasoning_parser_from_name` lowercases
/// the requested name before looking it up here, so new keys must be lowercase as well.
fn get_reasoning_parser_map() -> &'static HashMap<&'static str, ReasoningParserType> {
    REASONING_PARSER_MAP.get_or_init(|| {
        HashMap::from([
            ("deepseek_r1", ReasoningParserType::DeepseekR1),
            ("basic", ReasoningParserType::Basic),
            ("gpt_oss", ReasoningParserType::GptOss),
            ("qwen3", ReasoningParserType::Qwen),
            ("nemotron_deci", ReasoningParserType::NemotronDeci),
            ("kimi", ReasoningParserType::Kimi),
            ("step3", ReasoningParserType::Step3),
            ("mistral", ReasoningParserType::Mistral),
            ("granite", ReasoningParserType::Granite),
        ])
    })
}

/// Returns the names of all registered reasoning parsers, sorted alphabetically.
///
/// The names come from the keys of the global parser map. `HashMap` iteration order is
/// arbitrary, so the list is sorted to give callers a stable, reproducible ordering.
pub fn get_available_reasoning_parsers() -> Vec<&'static str> {
    let mut names: Vec<&'static str> = get_reasoning_parser_map().keys().copied().collect();
    names.sort_unstable();
    names
}

39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
/// Output of a parse call: the input split into normal text and reasoning text.
#[derive(Debug, Clone, Default)]
pub struct ParserResult {
    /// Text found outside of reasoning blocks.
    pub normal_text: String,

    /// Text extracted from inside reasoning blocks.
    pub reasoning_text: String,
}

impl ParserResult {
    /// Returns a copy of the reasoning text, or `None` when it is empty.
    pub fn get_some_reasoning(&self) -> Option<String> {
        let text = &self.reasoning_text;
        (!text.is_empty()).then(|| text.clone())
    }

    /// Returns a copy of the normal text, or `None` when it is empty.
    pub fn get_some_normal_text(&self) -> Option<String> {
        let text = &self.normal_text;
        (!text.is_empty()).then(|| text.clone())
    }
}

pub trait ReasoningParser: Send + std::fmt::Debug {
    /// Parses a standalone, non-streaming input chunk. Implementations may reset or ignore
    /// internal streaming state and should return the split of normal vs reasoning text for
    /// this complete input. Marker tokens must not be included in either output.
70
    fn detect_and_parse_reasoning(&mut self, text: &str, token_ids: &[u32]) -> ParserResult;
71
72
73
74

    /// Parses a streaming chunk and updates internal state. The return value should be the
    /// delta: only the newly discovered normal and reasoning text attributable to this chunk
    /// (not the cumulative totals). Marker tokens must not be included in either output.
75
76
77
78
79
    fn parse_reasoning_streaming_incremental(
        &mut self,
        text: &str,
        token_ids: &[u32],
    ) -> ParserResult;
80
81
82
83
84
85
}

/// The reasoning parser flavours this module can construct.
///
/// Marked `#[non_exhaustive]`, so code outside this crate must include a wildcard arm
/// when matching on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum ReasoningParserType {
    /// Registered under the name `deepseek_r1`.
    DeepseekR1,
    /// Registered under the name `step3`.
    Step3,
    /// Registered under the name `basic`; also the fallback for unknown parser names.
    Basic,
    /// Registered under the name `gpt_oss`.
    GptOss,
    /// Registered under the name `qwen3`.
    Qwen,
    /// Registered under the name `nemotron_deci`.
    NemotronDeci,
    /// Registered under the name `kimi`.
    Kimi,
    /// Registered under the name `mistral`.
    Mistral,
    /// Registered under the name `granite`.
    Granite,
}

#[derive(std::fmt::Debug)]
pub struct ReasoningParserWrapper {
    parser: Box<dyn ReasoningParser>,
}

impl ReasoningParser for ReasoningParserWrapper {
102
103
    fn detect_and_parse_reasoning(&mut self, text: &str, token_ids: &[u32]) -> ParserResult {
        self.parser.detect_and_parse_reasoning(text, token_ids)
104
105
    }

106
107
108
109
110
111
112
    fn parse_reasoning_streaming_incremental(
        &mut self,
        text: &str,
        token_ids: &[u32],
    ) -> ParserResult {
        self.parser
            .parse_reasoning_streaming_incremental(text, token_ids)
113
114
115
116
117
    }
}

impl ReasoningParserType {
    pub fn get_reasoning_parser(self) -> ReasoningParserWrapper {
118
119
120
121
        let basic_parser =
            BasicReasoningParser::new("<think>".into(), "</think>".into(), false, true);
        let force_reasoning_basic_parser =
            BasicReasoningParser::new("<think>".into(), "</think>".into(), true, true);
122
123
        match self {
            ReasoningParserType::DeepseekR1 => ReasoningParserWrapper {
124
125
126
127
                parser: Box::new(force_reasoning_basic_parser),
            },
            ReasoningParserType::Step3 => ReasoningParserWrapper {
                parser: Box::new(force_reasoning_basic_parser),
128
129
            },
            ReasoningParserType::Basic => ReasoningParserWrapper {
130
131
132
133
134
135
136
137
138
                parser: Box::new(basic_parser),
            },
            ReasoningParserType::Qwen => ReasoningParserWrapper {
                parser: Box::new(basic_parser),
            },
            ReasoningParserType::NemotronDeci => ReasoningParserWrapper {
                parser: Box::new(basic_parser),
            },
            ReasoningParserType::Kimi => ReasoningParserWrapper {
139
                parser: Box::new(BasicReasoningParser::new(
140
141
                    "◁think▷".into(),
                    "◁/think▷".into(),
142
143
144
145
                    false,
                    true,
                )),
            },
146
147
148
149
150
151
152
153
            ReasoningParserType::Mistral => ReasoningParserWrapper {
                parser: Box::new(BasicReasoningParser::new(
                    "[THINK]".into(),
                    "[/THINK]".into(),
                    true,
                    true,
                )),
            },
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
            ReasoningParserType::GptOss => match GptOssReasoningParser::new() {
                Ok(parser) => ReasoningParserWrapper {
                    parser: Box::new(parser),
                },
                Err(e) => {
                    tracing::warn!(
                        "GptOssReasoningParser could not be initialized, falling back to Basic Reasoning Parser: {e}"
                    );
                    ReasoningParserWrapper {
                        parser: Box::new(BasicReasoningParser::new(
                            "<think>".into(),
                            "</think>".into(),
                            false,
                            true,
                        )),
                    }
                }
            },
172
173
174
            ReasoningParserType::Granite => ReasoningParserWrapper {
                parser: Box::new(GraniteReasoningParser::new()),
            },
175
176
        }
    }
177
178

    pub fn get_reasoning_parser_from_name(name: &str) -> ReasoningParserWrapper {
179
180
181
182
183
184
185
186
        tracing::debug!("Selected reasoning parser: {}", name);

        let parser_map = get_reasoning_parser_map();
        let normalized_name = name.to_lowercase();

        match parser_map.get(normalized_name.as_str()) {
            Some(parser_type) => parser_type.get_reasoning_parser(),
            None => {
187
                tracing::warn!(
188
189
                    parser_name = name,
                    "Unknown reasoning parser type, falling back to Basic Reasoning Parser",
190
191
192
193
194
                );
                Self::Basic.get_reasoning_parser()
            }
        }
    }
195
}
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214

#[cfg(test)]
mod tests {
    use super::*;

    /// Every parser name listed below must be reported by
    /// `get_available_reasoning_parsers`.
    #[test]
    fn test_get_available_reasoning_parsers() {
        let parsers = get_available_reasoning_parsers();
        assert!(!parsers.is_empty());
        // Update this list when adding a new parser
        let available_parsers = [
            "deepseek_r1",
            "basic",
            "gpt_oss",
            "qwen3",
            "nemotron_deci",
            "kimi",
            "step3",
            "mistral",
            "granite",
        ];
        for parser in available_parsers {
            // Name the missing entry so a failure points directly at the gap.
            assert!(
                parsers.contains(&parser),
                "parser `{parser}` is missing from get_available_reasoning_parsers()"
            );
        }
    }
}