granite_reasoning_parser.py 14.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5

from collections.abc import Sequence

6
import regex as re
7
8
from transformers import PreTrainedTokenizerBase

9
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
10
from vllm.logger import init_logger
11
from vllm.reasoning import ReasoningParser
12
13
14
15
16
17
18
19
20
21
22
23

logger = init_logger(__name__)


class GraniteReasoningParser(ReasoningParser):
    """
    Reasoning parser for IBM Granite.

    IBM Granite models currently use the phrases "Here is my thought process:"
    and "Here is my response:" to separate their thinking and response outputs.
    """

24
25
    def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs):
        """Set up the marker phrases and the regex used to split Granite output.

        Args:
            tokenizer: Tokenizer forwarded to the base class constructor.
            *args: Extra positional arguments forwarded to the base class.
            **kwargs: Extra keyword arguments forwarded to the base class.
        """
        super().__init__(tokenizer, *args, **kwargs)

        # Literal marker phrases that the streaming helpers compare against.
        self.valid_think_starts = [
            "Here's my thought process:",
            "Here is my thought process:",
        ]
        self.valid_response_starts = ["Here's my response:", "Here is my response:"]

        # NOTE: Quantized instances of the current models have been observed
        # emitting "Here's" instead of "Here is", so both spellings are
        # accepted by the expressions below.
        self.think_start_expr = r"(?:Here's|Here is) my thought process:"
        self.response_start_expr = r"(?:Here's|Here is) my response:"

        # Group 1 captures the reasoning (non-greedy, up to the first response
        # phrase); group 2 captures everything after that phrase. DOTALL lets
        # both groups span newlines.
        full_pattern = (
            self.think_start_expr + r"(.*?)" + self.response_start_expr + r"(.*)"
        )
        self.reasoning_regex = re.compile(full_pattern, re.DOTALL)

        # Substrings used to detect marker-phrase boundaries on raw text: every
        # marker phrase begins with "Here" and ends with ":".
        self.seq_boundary_start = "Here"
        self.seq_boundary_end = ":"

        # Length of the longest start-of-thinking phrase.
        self.longest_think_start = len(max(self.valid_think_starts, key=len))
51

52
    def extract_reasoning(
53
        self, model_output: str, request: ChatCompletionRequest
54
    ) -> tuple[str | None, str | None]:
55
56
57
58
59
60
        """Extract the reasoning content & content sections, respectively.
        If the sequence doesn't match what we expect, i.e., the model generates
        something else, all content is considered non-reasoning content.

        Args:
            model_output (str): Output of the model to be parsed.
61
            request (ChatCompletionRequest): Request being processed.
62
63
64
65
66
67
68
69

        Returns:
            tuple[Optional[str], Optional[str]]: Tuple pair containing the
            reasoning content and non-reasoning content.
        """
        re_match = self.reasoning_regex.findall(model_output)
        if not re_match:
            return None, model_output
70
        reasoning, response_content = re_match[0]
71
        if not response_content:
72
73
            return reasoning, None
        return reasoning, response_content
74

75
    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        """Route one streamed delta to the reasoning or the response channel.

        If the text does not follow the expected phrase structure, i.e., the
        model generates something else, everything is treated as regular
        (non-reasoning) content.

        NOTE: Granite models do not use special tokens to open their reasoning
        and response sections; they emit plain-text phrases instead, e.g.,

                Here is my thought process: Foo Here is my response: Bar

        A phrase may therefore be split across several deltas, so text that
        could still turn out to be part of a phrase has to be held back until
        it is either completed or broken.

        Args:
            previous_text (str): Text generated before this delta (not used
                here; it is implied by current_text and delta_text).
            current_text (str): Previous text + delta text.
            delta_text (str): Text to consider and parse content from.
            previous_token_ids (Sequence[int]): Token IDs of previous_text
                (not used here).
            current_token_ids (Sequence[int]): Token IDs of current_text
                (not used here).
            delta_token_ids (Sequence[int]): Token IDs of delta_text
                (not used here).

        Returns:
            DeltaMessage | None: A message carrying reasoning and/or content,
            or None when this delta yields nothing to emit.
        """
        thinking, response_seq_len, response = self._get_content_sections(
            current_text
        )

        if thinking and response:
            # Both the thought-process phrase and the response phrase are
            # complete, and text follows the response phrase.
            # The length is always set when a response phrase was matched.
            assert response_seq_len is not None
            message = self._get_delta_message_with_both_bounds(
                delta_text, thinking, response, current_text, response_seq_len
            )
        elif thinking:
            # Reasoning has started, but no text follows a complete response
            # phrase yet.
            message = self._get_delta_message_with_no_response_bounds(
                current_text, thinking, delta_text
            )
        else:
            # Either the thought-process phrase is not finished yet, or the
            # model is generating something unexpected.
            message = self._get_delta_message_with_no_reasoning_bounds(
                current_text, delta_text
            )

        if message.content or message.reasoning:
            return message
        return None

    #### Implementation details of stream parsing for granite models
    def _is_reasoning_start_substr(self, text: str) -> bool:
        """Check if a text matches one of the possible start reasoning seqs.

        Args:
            text (str): Text to check for leading substr.
140

141
142
143
144
        Returns:
            bool: True if any of the possible reasoning start seqs match.
        """
        return any(
145
146
            think_start.startswith(text) for think_start in self.valid_think_starts
        )
147
148
149
150
151
152

    def _is_response_start_substr(self, text: str) -> bool:
        """Check if a text matches one of the possible start response seqs.

        Args:
            text (str): Text to check for leading substr.
153

154
155
156
157
158
        Returns:
            bool: True if any of the possible response start seqs match.
        """
        return any(
            response_start.startswith(text)
159
160
            for response_start in self.valid_response_starts
        )
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178

    def _get_delta_message_with_no_reasoning_bounds(
        self,
        current_text: str,
        delta_text: str,
    ) -> DeltaMessage:
        """Build the delta message while no start-of-thinking phrase is complete.

        Text that could still grow into a start-of-thinking phrase is held
        back; once that possibility is ruled out, everything held back so far
        is released as regular content.

        Args:
            current_text (str): The full previous + delta text.
            delta_text (str): Text to consider and parse content from.

        Returns:
            DeltaMessage: Message containing the parsed content; its
            ``reasoning`` field is always None here.
        """
        # Still a prefix of a start-of-thinking phrase: emit nothing yet.
        if self._is_reasoning_start_substr(current_text):
            return DeltaMessage(reasoning=None, content=None)

        # The text before this delta was still a candidate prefix, so none of
        # it has been emitted; this delta broke the phrase, so flush it all.
        prior_len = len(current_text) - len(delta_text)
        if self._is_reasoning_start_substr(current_text[:prior_len]):
            return DeltaMessage(
                reasoning=None,
                content=current_text,
            )

        # The phrase was already broken by an earlier delta and the held-back
        # text was flushed then; pass this delta through as normal content.
        return DeltaMessage(reasoning=None, content=delta_text)
195
196
197
198

    def _get_delta_message_with_no_response_bounds(
        self,
        current_text: str,
        reasoning: str,
        delta_text: str,
    ) -> DeltaMessage:
        """Build the delta message while reasoning is open and no response yet.

        NOTE that the delta may overlap the start-of-thinking phrase on one
        side and a (possibly partial) start-of-response phrase on the other.
        Text that could still grow into a start-of-response phrase is held
        back and only released as reasoning once the phrase is broken.

        Args:
            current_text (str): The full previous + delta text.
            reasoning (str): reasoning content from current_text.
            delta_text (str): Text to consider and parse content from.

        Returns:
            DeltaMessage: Message containing the parsed content; its
            ``content`` field is always None here.
        """
        # With no reasoning, or when the text ends exactly on a complete
        # start-of-response phrase, we are transitioning to the response.
        # The closing ":" completes the phrase and must not be passed back.
        if reasoning is None or current_text.endswith(
            tuple(self.valid_response_starts)
        ):
            return DeltaMessage(reasoning=None, content=None)

        # From here on only the reasoning section is considered: the part of
        # it that precedes this delta, and the delta itself.
        prior_reasoning = reasoning[: -len(delta_text)]

        # Last place in each piece where a response phrase MIGHT begin.
        marker = self.seq_boundary_start
        prior_marker_idx = prior_reasoning.rfind(marker)
        delta_marker_idx = delta_text.rfind(marker)

        had_pending_candidate = False
        if prior_marker_idx >= 0:
            # The delta only extends a candidate response phrase that began
            # before it; keep holding the text back.
            if self._is_response_start_substr(reasoning[prior_marker_idx:]):
                return DeltaMessage(reasoning=None, content=None)
            had_pending_candidate = self._is_response_start_substr(
                prior_reasoning[prior_marker_idx:]
            )

        # Does the tail of the delta itself open a new candidate phrase?
        delta_opens_candidate = (
            delta_marker_idx >= 0
            and self._is_response_start_substr(delta_text[delta_marker_idx:])
        )

        # Emit the delta, minus any new candidate tail that must be held back.
        emitted = delta_text[:delta_marker_idx] if delta_opens_candidate else delta_text

        # An earlier candidate turned out not to be a response phrase; the
        # text held back for it has to be released together with this delta.
        if had_pending_candidate:
            emitted = prior_reasoning[prior_marker_idx:] + emitted

        return DeltaMessage(reasoning=emitted, content=None)

    def _get_delta_message_with_both_bounds(
        self,
        delta_text: str,
        reasoning: str,
        response_content: str,
        current_text: str,
        response_seq_len: int,
    ) -> DeltaMessage:
        """Build the delta message once both marker phrases are complete.

        The delta is split by position: the tail that lies after the
        start-of-response phrase is content, and whatever lies between the end
        of the start-of-thinking phrase and the beginning of the
        start-of-response phrase is reasoning.

        Args:
            delta_text: Text to consider and parse content from.
            reasoning: reasoning content from current_text. Kept for interface
                compatibility; the offsets below are derived from
                current_text instead, so its exact length does not matter.
            response_content: response content from current_text.
            current_text: The full previous + delta text.
            response_seq_len: Len of the complete response sequence used.

        Returns:
            DeltaMessage: Message containing the parsed content. ``reasoning``
            is None when the delta does not reach back past the beginning of
            the start-of-response phrase.
        """
        # Content is the tail of the delta after the response phrase. The
        # start index is clamped so that a response longer than the delta
        # yields the whole delta, and an empty response yields "" (a plain
        # ``delta_text[-0:]`` would return the entire delta instead).
        content_start = max(len(delta_text) - len(response_content), 0)
        delta_content = delta_text[content_start:]

        # Index inside the delta at which the start-of-response phrase begins.
        reasoning_end_idx = len(delta_text) - (len(response_content) + response_seq_len)
        if reasoning_end_idx < 0:
            # The phrase began before this delta; no reasoning text in it.
            return DeltaMessage(reasoning=None, content=delta_content)

        # The reasoning begins right after the start-of-thinking phrase that
        # prefixes current_text. Translate that absolute position into the
        # delta's coordinates, clamping to 0 when the delta starts inside the
        # reasoning (the usual case). Anchoring on the phrase itself keeps
        # long deltas from losing their leading reasoning text and keeps
        # phrase text out of the reasoning.
        think_start_len = next(
            (
                len(think_start)
                for think_start in self.valid_think_starts
                if current_text.startswith(think_start)
            ),
            0,
        )
        delta_offset = len(current_text) - len(delta_text)
        start_offset = max(think_start_len - delta_offset, 0)
        delta_reasoning = delta_text[start_offset:reasoning_end_idx]

        return DeltaMessage(
            reasoning=delta_reasoning,
            content=delta_content,
        )

    def _get_content_sections(
        self, current_text: str
319
    ) -> tuple[str | None, int | None, str | None]:
320
321
322
323
324
325
326
327
328
329
330
331
        """Parse the text to extract the reasoning content / content
        if we have them.

        Args:
            current_text (str): The full previous + delta text.

        Returns:
            tuple[Optional[str], Optional[int], Optional[str]]: Tuple of len 3
            containing the reasoning content, the length of the response seq
            (if there is one) and the non-reasoning content.
        """
        current_chunk_start = 0
332
        start_reasoning = None
333
334
        parsed_content = False
        delimiter_idxs = [
335
336
            idx
            for idx, char in enumerate(current_text)
337
338
339
340
341
342
            if char == self.seq_boundary_end
        ]

        for current_chunk_end in delimiter_idxs:
            current_chunk = current_text[current_chunk_start:current_chunk_end]
            # Check to see if the start of reasoning seq if complete
343
            if start_reasoning is None:
344
345
                for think_start in self.valid_think_starts:
                    if current_chunk == think_start[:-1]:
346
                        start_reasoning = current_chunk_end + 1
347
348
349
350
351
352
                        current_chunk_start = current_chunk_end + 1
                        break

            # Check to see if the start of response seq if complete
            elif not parsed_content:
                for response_start in self.valid_response_starts:
353
                    if current_chunk[-len(response_start) + 1 :] == response_start[:-1]:
354
355
                        # Mark end of reasoning and start response content
                        # after the start of response sequence.
356
357
                        end_reasoning = current_chunk_end - len(response_start)
                        reasoning = current_text[start_reasoning:end_reasoning]
358
                        response_content = current_text[current_chunk_end + 1 :]
359
                        return reasoning, len(response_start), response_content
360

361
362
        if start_reasoning and not parsed_content:
            return current_text[start_reasoning:], None, None
363
        return None, None, None