# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Sequence

import regex as re
from transformers import PreTrainedTokenizerBase

from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser, ReasoningParserManager

logger = init_logger(__name__)


@ReasoningParserManager.register_module("granite")
class GraniteReasoningParser(ReasoningParser):
    """
    Reasoning parser for IBM Granite.

    IBM granite models currently use "Here is my thought process:"
    and "Here is my response:" to separate its thinking / response outputs.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs):
        super().__init__(tokenizer, *args, **kwargs)

        # NOTE: There have been some observed occurrences of quantized
        # instances of the current models using "Here's" instead of "Here is",
        # so to be safe, we match on both.
        self.think_start_expr = r"(?:Here's|Here is) my thought process:"
        self.response_start_expr = r"(?:Here's|Here is) my response:"

        # Used by the non-streaming path: group 1 is the reasoning, group 2
        # is everything after the start of response sequence.
        self.reasoning_regex = re.compile(
            rf"{self.think_start_expr}(.*?){self.response_start_expr}(.*)", re.DOTALL
        )

        self.valid_think_starts = [
            "Here's my thought process:",
            "Here is my thought process:",
        ]
        self.valid_response_starts = ["Here's my response:", "Here is my response:"]

        # Substrings that bound the special sequences in raw text; every
        # valid start sequence ends with `seq_boundary_end`.
        self.seq_boundary_end = ":"
        self.seq_boundary_start = "Here"

        # Length of the longest start of reasoning sequence.
        self.longest_think_start = max(
            len(think_start) for think_start in self.valid_think_starts
        )

    def extract_reasoning_content(
        self, model_output: str, request: ChatCompletionRequest
    ) -> tuple[str | None, str | None]:
        """Extract the reasoning content & content sections, respectively.
        If the sequence doesn't match what we expect, i.e., the model generates
        something else, all content is considered non-reasoning content.

        Args:
            model_output (str): Output of the model to be parsed.
            request (ChatCompletionRequest): Request being processed.

        Returns:
            tuple[Optional[str], Optional[str]]: Tuple pair containing the
            reasoning content and non-reasoning content. The content is None
            when nothing follows the start of response sequence.
        """
        re_match = self.reasoning_regex.search(model_output)
        if re_match is None:
            return None, model_output
        reasoning_content, response_content = re_match.groups()
        return reasoning_content, response_content or None

    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        """Extract the reasoning content / content emitted by granite models;
        If the sequence doesn't match what we expect, i.e., the model generates
        something else, all content is considered non-reasoning content.

        NOTE: Granite models do not use a special token to start their reasoning
        and response sections; instead they have token sequences, e.g.,

                Here is my thought process: Foo Here is my response: Bar

        This increases the complexity of correctly handling streams, since we
        need to watch for specific sequences and correctly parse them without
        dropping content that is potentially overlapping & spanning multiple
        delta messages. Reasoning text at the very end of the stream that
        could still grow into the start of response sequence (e.g. a trailing
        "Here is my") is held back until a later delta disambiguates it.

        Args:
            previous_text (str): Previous text outside of this delta message.
            current_text (str): Previous text + delta text.
            delta_text (str): Text to consider and parse content from.
            previous_token_ids (Sequence[int]): Token IDs of previous_text.
            current_token_ids (Sequence[int]): Token IDs of current_text.
            delta_token_ids (Sequence[int]): Token IDs of delta_text.

        Returns:
            Union[DeltaMessage, None]
                DeltaMessage with either reasoning content or content, or None.
        """
        reasoning_content, resp_seq_len, content = self._get_content_sections(
            current_text
        )

        # Either we haven't finished the start of the reasoning sequence,
        # or the model is generating something unexpected. NOTE: an empty
        # reasoning section ("") still means the start sequence is complete.
        if reasoning_content is None:
            delta_message = self._get_delta_message_with_no_reasoning_bounds(
                current_text, delta_text
            )
        # We have a start of reasoning message, but have not yet finished
        # the start of response sequence.
        elif resp_seq_len is None:
            delta_message = self._get_delta_message_with_no_response_bounds(
                current_text, reasoning_content, delta_text
            )
        # We've finished both the start of reasoning and start of response seq.
        else:
            # _get_content_sections always returns content with a seq length.
            assert content is not None
            delta_message = self._get_delta_message_with_both_bounds(
                delta_text, reasoning_content, content, current_text, resp_seq_len
            )

        if not delta_message.content and not delta_message.reasoning_content:
            return None
        return delta_message

    #### Implementation details of stream parsing for granite models
    def _is_reasoning_start_substr(self, text: str) -> bool:
        """Check if a text matches one of the possible start reasoning seqs.

        Args:
            text (str): Text to check for leading substr.

        Returns:
            bool: True if any of the possible reasoning start seqs match.
        """
        return any(
            think_start.startswith(text) for think_start in self.valid_think_starts
        )

    def _is_response_start_substr(self, text: str) -> bool:
        """Check if a text matches one of the possible start response seqs.

        Args:
            text (str): Text to check for leading substr.

        Returns:
            bool: True if any of the possible response start seqs match.
        """
        return any(
            response_start.startswith(text)
            for response_start in self.valid_response_starts
        )

    def _get_held_suffix_start(self, text: str) -> int:
        """Find where a trailing (possibly partial) start of response seq begins.

        Reasoning is held back from the first index whose suffix could still
        become one of `self.valid_response_starts` (e.g. a trailing "He" or
        "Here is my re"). Text before that index is safe to emit as reasoning.
        Because a prefix of such a suffix is itself such a prefix, this index
        never moves backwards as more text is appended, so callers can use it
        to know exactly how much reasoning was emitted by earlier deltas.

        Args:
            text (str): Reasoning text to inspect.

        Returns:
            int: Smallest index `i` such that `text[i:]` is a non-empty prefix
            of a valid start of response sequence, or `len(text)` if none.
        """
        longest_response_start = max(
            len(response_start) for response_start in self.valid_response_starts
        )
        for idx in range(max(0, len(text) - longest_response_start), len(text)):
            if self._is_response_start_substr(text[idx:]):
                return idx
        return len(text)

    def _get_delta_message_with_no_reasoning_bounds(
        self,
        current_text: str,
        delta_text: str,
    ) -> DeltaMessage:
        """Parse the delta message when the current text has not yet completed
        its start of reasoning sequence.

        Args:
            current_text (str): The full previous + delta text.
            delta_text (str): Text to consider and parse content from.

        Returns:
            DeltaMessage: Message containing the parsed content.
        """
        previous_len = len(current_text) - len(delta_text)
        is_substr = self._is_reasoning_start_substr(current_text)
        was_substr = self._is_reasoning_start_substr(current_text[:previous_len])

        # Check if we just generated something NOT in the special token seq;
        # if so, add everything that we previously skipped with this delta
        # message and append everything to content in the future.
        if was_substr and not is_substr:
            return DeltaMessage(
                reasoning_content=None,
                content=current_text,
            )
        if is_substr:
            # Might still be in the special token sequence; return nothing
            return DeltaMessage(reasoning_content=None, content=None)
        # Otherwise the sequence has already been broken and we already
        # corrected; just return the delta text as normal content.
        return DeltaMessage(reasoning_content=None, content=delta_text)

    def _get_delta_message_with_no_response_bounds(
        self,
        current_text: str,
        reasoning_content: str,
        delta_text: str,
    ) -> DeltaMessage:
        """Parse the delta message when the current text has reasoning content
        but no complete start of response sequence yet.

        The delta may overlap the start of reasoning sequence (e.g. a ":\\n"
        token finishing it), so only the part of the delta inside the
        reasoning section is considered. Any trailing text that could still
        be the start of the response sequence is held back; text held back by
        an earlier delta is released once it stops matching.

        Args:
            current_text (str): The full previous + delta text (not needed
                beyond `reasoning_content`; kept for interface compatibility).
            reasoning_content (str): reasoning content from current_text; it
                is always a suffix of current_text here.
            delta_text (str): Text to consider and parse content from.

        Returns:
            DeltaMessage: Message containing the parsed reasoning content.
        """
        if reasoning_content is None:
            return DeltaMessage(reasoning_content=None, content=None)

        # Reasoning as of the previous delta; clamps to "" when the delta
        # reaches back into the start of reasoning sequence.
        previous_len = max(0, len(reasoning_content) - len(delta_text))
        previous_reasoning = reasoning_content[:previous_len]

        # Earlier deltas emitted reasoning up to their held-back suffix; this
        # delta emits everything from there up to the new held-back suffix.
        emit_start = self._get_held_suffix_start(previous_reasoning)
        emit_end = self._get_held_suffix_start(reasoning_content)
        return DeltaMessage(
            reasoning_content=reasoning_content[emit_start:emit_end],
            content=None,
        )

    def _get_delta_message_with_both_bounds(
        self,
        delta_text: str,
        reasoning_content: str,
        response_content: str,
        current_text: str,
        response_seq_len: int,
    ) -> DeltaMessage:
        """Parse the delta message when the current text has both reasoning
        content and normal (response) content.

        `current_text` is laid out as

            <start of reasoning seq><reasoning_content><response seq><response_content>

        and the delta is a suffix of it that may overlap any of these parts.

        Args:
            delta_text: Text to consider and parse content from.
            reasoning_content: reasoning content from current_text.
            response_content: response content from current_text (may be "").
            current_text: The full previous + delta text.
            response_seq_len: Len of the complete response sequence used.

        Returns:
            DeltaMessage: Message containing the parsed content; a field is
            None when the delta contributes nothing to it.
        """
        delta_offset = len(current_text) - len(delta_text)
        response_seq_start = (
            len(current_text) - len(response_content) - response_seq_len
        )
        reasoning_start = response_seq_start - len(reasoning_content)

        # Content: the part of the delta after the start of response sequence.
        content_start = max(delta_offset, response_seq_start + response_seq_len)
        delta_content = current_text[content_start:]

        # Reasoning: earlier deltas emitted reasoning only up to any trailing
        # text that might have been the response sequence; resume there so
        # held-back text that was NOT the response sequence is flushed now.
        # (The slice is "" when the delta reaches back before the reasoning.)
        previous_reasoning = current_text[reasoning_start:delta_offset]
        emit_start = reasoning_start + self._get_held_suffix_start(previous_reasoning)
        delta_reasoning_content = current_text[emit_start:response_seq_start]

        return DeltaMessage(
            reasoning_content=delta_reasoning_content or None,
            content=delta_content or None,
        )

    def _get_content_sections(
        self, current_text: str
    ) -> tuple[str | None, int | None, str | None]:
        """Parse the text to extract the reasoning content / content
        if we have them.

        Args:
            current_text (str): The full previous + delta text.

        Returns:
            tuple[Optional[str], Optional[int], Optional[str]]: Tuple of len 3
            containing the reasoning content, the length of the response seq
            (if there is one) and the non-reasoning content. The reasoning is
            None while the start of reasoning sequence is incomplete (or never
            appears); the last two are None until the start of response
            sequence is complete.
        """
        delimiter = self.seq_boundary_end

        # The start of reasoning seq must open the text. Its only delimiter is
        # its final character, so it can only end at the first delimiter: any
        # later chunk measured from index 0 would contain that delimiter.
        think_end = current_text.find(delimiter)
        if think_end < 0:
            return None, None, None
        think_chunk = current_text[:think_end]
        if not any(
            think_chunk == think_start[:-1] for think_start in self.valid_think_starts
        ):
            return None, None, None
        start_reasoning_content = think_end + 1

        # Every later delimiter is a candidate end of the start of response seq.
        current_chunk_end = current_text.find(delimiter, start_reasoning_content)
        while current_chunk_end >= 0:
            for response_start in self.valid_response_starts:
                if current_text.endswith(
                    response_start[:-1], start_reasoning_content, current_chunk_end
                ):
                    # The response seq spans
                    # [current_chunk_end - len(response_start) + 1,
                    #  current_chunk_end]; reasoning stops right before it.
                    end_reasoning_content = current_chunk_end - len(response_start) + 1
                    reasoning_content = current_text[
                        start_reasoning_content:end_reasoning_content
                    ]
                    response_content = current_text[current_chunk_end + 1 :]
                    return reasoning_content, len(response_start), response_content
            current_chunk_end = current_text.find(delimiter, current_chunk_end + 1)

        return current_text[start_reasoning_content:], None, None