traits.rs 3.74 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
use std::fmt;

/// Outcome of a reasoning parse: the input split into normal and reasoning parts.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct ParserResult {
    /// Text that was found outside of reasoning blocks.
    pub normal_text: String,

    /// Text that was extracted from inside reasoning blocks.
    pub reasoning_text: String,
}

impl ParserResult {
    /// Build a result from both parts: normal text first, reasoning text second.
    pub fn new(normal_text: String, reasoning_text: String) -> Self {
        Self {
            normal_text,
            reasoning_text,
        }
    }

    /// Build a result that carries only normal text; the reasoning part is empty.
    pub fn normal(text: String) -> Self {
        Self::new(text, String::new())
    }

    /// Build a result that carries only reasoning text; the normal part is empty.
    pub fn reasoning(text: String) -> Self {
        Self::new(String::new(), text)
    }

    /// Returns `true` when neither the normal nor the reasoning part holds any text.
    pub fn is_empty(&self) -> bool {
        [&self.normal_text, &self.reasoning_text]
            .iter()
            .all(|part| part.is_empty())
    }
}

/// Trait for parsing reasoning content from LLM outputs.
///
/// Implementors must be `Send + Sync`. Every method that parses or resets
/// takes `&mut self`, so a parser instance is free to keep state between
/// calls; only [`model_type`](ReasoningParser::model_type) borrows immutably.
pub trait ReasoningParser: Send + Sync {
    /// Detects and parses reasoning from the input text (one-time parsing).
    ///
    /// This method is used for non-streaming scenarios where the complete
    /// text is available at once.
    ///
    /// # Errors
    ///
    /// Returns an error if the text exceeds buffer limits or contains invalid UTF-8.
    ///
    /// NOTE(review): `text` is a `&str`, which is always valid UTF-8, so a
    /// UTF-8 failure cannot come from this argument directly — confirm which
    /// implementations (if any) can actually produce that error here.
    fn detect_and_parse_reasoning(&mut self, text: &str) -> Result<ParserResult, ParseError>;

    /// Parses reasoning incrementally from streaming input.
    ///
    /// `text` is the next chunk of the stream. This method maintains internal
    /// state across calls to handle partial tokens and chunk boundaries
    /// correctly, which is why it takes `&mut self`.
    ///
    /// # Errors
    ///
    /// Returns an error if the buffer exceeds max_buffer_size.
    fn parse_reasoning_streaming_incremental(
        &mut self,
        text: &str,
    ) -> Result<ParserResult, ParseError>;

    /// Reset the parser state for reuse.
    ///
    /// This should clear any buffers and reset flags to initial state, so the
    /// same instance can be used for a new, unrelated input afterwards.
    fn reset(&mut self);

    /// Get the model type this parser is designed for.
    ///
    /// The value is borrowed from the parser; callers that need to keep it
    /// beyond the borrow must copy it.
    fn model_type(&self) -> &str;
}

/// Error types for reasoning parsing operations.
///
/// The `Display` text of each variant comes from its `#[error(...)]`
/// attribute, supplied by the `thiserror::Error` derive.
#[derive(Debug, thiserror::Error)]
pub enum ParseError {
    /// Wraps a [`std::str::Utf8Error`]. The `#[from]` attribute generates the
    /// `From` conversion, so `?` can turn a `Utf8Error` into this variant.
    #[error("Invalid UTF-8 in stream: {0}")]
    Utf8Error(#[from] std::str::Utf8Error),

    /// A buffer size limit was exceeded. The payload is the byte count shown
    /// in the message.
    ///
    /// NOTE(review): presumably this is the size that was attempted rather
    /// than the configured limit — confirm at the construction sites.
    #[error("Buffer overflow: {0} bytes exceeds maximum")]
    BufferOverflow(usize),

    /// No parser is known for the requested model type; the payload is the
    /// model type string echoed in the message.
    #[error("Unknown model type: {0}")]
    UnknownModel(String),

    /// The parser configuration is invalid; the payload is a free-form
    /// description echoed in the message.
    #[error("Parser configuration error: {0}")]
    ConfigError(String),
}

/// Settings that control how a parser behaves.
#[derive(Debug, Clone)]
pub struct ParserConfig {
    /// Token marking where reasoning content begins.
    pub think_start_token: String,

    /// Token marking where reasoning content ends.
    pub think_end_token: String,

    /// When set, all text is forced to be treated as reasoning.
    pub force_reasoning: bool,

    /// When set, reasoning content is streamed as it arrives.
    pub stream_reasoning: bool,

    /// Upper bound on the buffer size, in bytes.
    pub max_buffer_size: usize,
}

impl Default for ParserConfig {
    /// Defaults: `<think>` / `</think>` delimiters, reasoning not forced,
    /// reasoning streamed, and a 64 KiB (65536-byte) buffer limit.
    fn default() -> Self {
        let think_start_token = String::from("<think>");
        let think_end_token = String::from("</think>");

        Self {
            think_start_token,
            think_end_token,
            force_reasoning: false,
            stream_reasoning: true,
            max_buffer_size: 64 * 1024, // 65536 bytes
        }
    }
}

impl fmt::Display for ParserResult {
    /// Writes a compact summary of the result: the number of characters in
    /// the normal text and in the reasoning text, without the text itself.
    ///
    /// Output shape: `ParserResult { normal: N chars, reasoning: M chars }`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `String::len()` is a byte count; the message says "chars", so count
        // Unicode scalar values instead. For ASCII text the two are equal.
        let normal_chars = self.normal_text.chars().count();
        let reasoning_chars = self.reasoning_text.chars().count();

        write!(
            f,
            "ParserResult {{ normal: {} chars, reasoning: {} chars }}",
            normal_chars, reasoning_chars
        )
    }
}