state.rs 5.96 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
use crate::tool_parser::types::{PartialToolCall, ToolCall};

/// Phase the streaming tool-call parser is currently in.
///
/// Variants are listed in the order a single tool call normally moves through them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParsePhase {
    /// Scanning input for the beginning of a tool call.
    Searching,
    /// Reading the function name of the current tool call.
    InName,
    /// Reading the arguments of the current tool call.
    InArguments,
    /// The current tool call has been fully read.
    Complete,
}

/// State for streaming parser
#[derive(Debug, Clone)]
pub struct ParseState {
    /// Buffer for accumulating input
    pub buffer: String,
    /// Position of last consumed character
    pub consumed: usize,
    /// Current partial tool being parsed
    pub partial_tool: Option<PartialToolCall>,
    /// Completed tool calls
    pub completed_tools: Vec<ToolCall>,
    /// Current parsing phase
    pub phase: ParsePhase,
    /// Bracket/brace depth for JSON parsing
    pub bracket_depth: i32,
    /// Whether currently inside a string literal
    pub in_string: bool,
    /// Whether next character should be escaped
    pub escape_next: bool,
    /// Current tool index (for streaming)
    pub tool_index: usize,
37
38
    /// Optional Harmony-specific streaming state (populated by token-aware parsers)
    pub harmony_stream: Option<HarmonyStreamState>,
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
}

impl ParseState {
    /// Create a new parse state
    pub fn new() -> Self {
        Self {
            buffer: String::new(),
            consumed: 0,
            partial_tool: None,
            completed_tools: Vec::new(),
            phase: ParsePhase::Searching,
            bracket_depth: 0,
            in_string: false,
            escape_next: false,
            tool_index: 0,
54
            harmony_stream: None,
55
56
57
58
59
60
61
62
63
64
        }
    }

    /// Reset state for parsing next tool
    pub fn reset(&mut self) {
        self.partial_tool = None;
        self.phase = ParsePhase::Searching;
        self.bracket_depth = 0;
        self.in_string = false;
        self.escape_next = false;
65
        self.harmony_stream = None;
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
    }

    /// Process a single character for JSON parsing
    pub fn process_char(&mut self, ch: char) {
        // Handle escape sequences
        if self.escape_next {
            self.escape_next = false;
            self.buffer.push(ch);
            return;
        }

        if ch == '\\' && self.in_string {
            self.escape_next = true;
            self.buffer.push(ch);
            return;
        }

        // Track string boundaries
        if ch == '"' && !self.escape_next {
            self.in_string = !self.in_string;
        }

        // Track bracket depth for JSON
        if !self.in_string {
            match ch {
                '{' | '[' => {
                    self.bracket_depth += 1;
                }
                '}' | ']' => {
                    self.bracket_depth -= 1;
                    if self.bracket_depth == 0 && self.partial_tool.is_some() {
                        // Complete tool call found
                        self.phase = ParsePhase::Complete;
                    }
                }
                _ => {}
            }
        }

        self.buffer.push(ch);
    }

    /// Check if we have a complete JSON object/array
    pub fn has_complete_json(&self) -> bool {
        self.bracket_depth == 0 && !self.in_string && !self.buffer.is_empty()
    }

    /// Extract content from buffer starting at position
    pub fn extract_from(&self, start: usize) -> &str {
        if start >= self.buffer.len() {
            return "";
        }

        // Find the nearest character boundary at or after start
        let mut safe_start = start;
        while safe_start < self.buffer.len() && !self.buffer.is_char_boundary(safe_start) {
            safe_start += 1;
        }

        if safe_start < self.buffer.len() {
            &self.buffer[safe_start..]
        } else {
            ""
        }
    }

    /// Mark content as consumed up to position
    pub fn consume_to(&mut self, position: usize) {
        if position > self.consumed {
            self.consumed = position;
        }
    }

    /// Get unconsumed content
    pub fn unconsumed(&self) -> &str {
        if self.consumed >= self.buffer.len() {
            return "";
        }

        // Find the nearest character boundary at or after consumed
        let mut safe_consumed = self.consumed;
        while safe_consumed < self.buffer.len() && !self.buffer.is_char_boundary(safe_consumed) {
            safe_consumed += 1;
        }

        if safe_consumed < self.buffer.len() {
            &self.buffer[safe_consumed..]
        } else {
            ""
        }
    }

    /// Clear consumed content from buffer
    pub fn clear_consumed(&mut self) {
        if self.consumed > 0 {
            // Find the nearest character boundary at or before consumed
            let mut safe_consumed = self.consumed;
            while safe_consumed > 0 && !self.buffer.is_char_boundary(safe_consumed) {
                safe_consumed -= 1;
            }

            if safe_consumed > 0 {
                self.buffer.drain(..safe_consumed);
                self.consumed = self.consumed.saturating_sub(safe_consumed);
            }
        }
    }

    /// Add completed tool
    pub fn add_completed_tool(&mut self, tool: ToolCall) {
        self.completed_tools.push(tool);
        self.tool_index += 1;
    }
}

impl Default for ParseState {
    fn default() -> Self {
        Self::new()
    }
}
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202

/// Harmony-specific streaming bookkeeping recorded during token-aware parsing.
///
/// All fields start empty, zero or `false` via `Default`.
#[derive(Debug, Clone, Default)]
pub struct HarmonyStreamState {
    /// Every token seen so far for the current assistant response.
    pub tokens: Vec<u32>,
    /// How many of `tokens` the Harmony parser has already handled.
    pub processed_tokens: usize,
    /// Count of tool calls already emitted downstream.
    pub emitted_calls: usize,
    /// Analysis-channel text waiting to be flushed into normal text output.
    pub analysis_buffer: String,
    /// Set once the current call's tool name has been surfaced.
    pub emitted_name: bool,
    /// Set once the current call's arguments have been surfaced.
    pub emitted_args: bool,
}