granite_reasoning_parser.py 14.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5

from collections.abc import Sequence

6
import regex as re
7
8
from transformers import PreTrainedTokenizerBase

9
10
11
12
from vllm.entrypoints.openai.chat_completion.protocol import (
    ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import DeltaMessage
13
from vllm.logger import init_logger
14
from vllm.reasoning import ReasoningParser
15
16
17
18
19
20
21
22
23
24
25
26

# Module-level logger created through vLLM's init_logger, named after this module.
logger = init_logger(__name__)


class GraniteReasoningParser(ReasoningParser):
    """
    Reasoning parser for IBM Granite.

    IBM Granite models currently use "Here is my thought process:" and
    "Here is my response:" (or the "Here's ..." variants) to separate their
    thinking / response outputs, instead of dedicated special tokens.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs):
        super().__init__(tokenizer, *args, **kwargs)

        # NOTE: There have been some observed occurrences of quantized
        # instances of the current models using "Here's" instead of "Here is",
        # so to be safe, we match on both.
        self.think_start_expr = r"(?:Here's|Here is) my thought process:"
        self.response_start_expr = r"(?:Here's|Here is) my response:"

        # Non-streaming parser: group 1 is the reasoning between the two
        # sequences, group 2 is everything after the start of response.
        self.reasoning_regex = re.compile(
            rf"{self.think_start_expr}(.*?){self.response_start_expr}(.*)", re.DOTALL
        )

        # Literal forms of both sequences, used by the streaming parser.
        self.valid_think_starts = [
            "Here's my thought process:",
            "Here is my thought process:",
        ]
        self.valid_response_starts = ["Here's my response:", "Here is my response:"]

        # Substrings to match for sequence boundaries on raw text. These are
        # no longer read by the parsing methods of this class, which match
        # the full sequences directly; they remain available as attributes.
        self.seq_boundary_end = ":"
        self.seq_boundary_start = "Here"

        # The longest any start of thinking / start of response seq can be.
        self.longest_think_start = max(
            len(think_start) for think_start in self.valid_think_starts
        )
        self.longest_response_start = max(
            len(response_start) for response_start in self.valid_response_starts
        )

    def extract_reasoning(
        self, model_output: str, request: ChatCompletionRequest
    ) -> tuple[str | None, str | None]:
        """Extract the reasoning content & content sections, respectively.
        If the sequence doesn't match what we expect, i.e., the model generates
        something else, all content is considered non-reasoning content.

        Args:
            model_output (str): Output of the model to be parsed.
            request (ChatCompletionRequest): Request being processed (unused).

        Returns:
            tuple[str | None, str | None]: Tuple pair containing the
            reasoning content and non-reasoning content. The content is None
            when nothing follows the start of response sequence.
        """
        re_match = self.reasoning_regex.search(model_output)
        if re_match is None:
            return None, model_output
        reasoning, response_content = re_match.groups()
        return reasoning, response_content or None

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        """Extract the reasoning content / content emitted by granite models;
        If the sequence doesn't match what we expect, i.e., the model generates
        something else, all content is considered non-reasoning content.

        NOTE: Granite models do not use a special token to start their reasoning
        and response sections; instead they have token sequences, e.g.,

                Here is my thought process: Foo Here is my response: Bar

        Such sequences may overlap and span multiple delta messages. To avoid
        both dropping and leaking text, the parser computes which reasoning /
        content is safe to show for the text before and after the delta. It
        returns the difference between the two. Text that might still grow
        into one of the sequences is withheld until it is resolved.

        Args:
            previous_text (str): Previous text outside of this delta message.
                Not read directly; it is re-derived from current_text and
                delta_text so both snapshots are always consistent.
            current_text (str): Previous text + delta text.
            delta_text (str): Text to consider and parse content from.
            previous_token_ids (Sequence[int]): Token IDs of previous_text.
            current_token_ids (Sequence[int]): Token IDs of current_text.
            delta_token_ids (Sequence[int]): Token IDs of delta_text.

        Returns:
            DeltaMessage | None: DeltaMessage with reasoning content and/or
            content, or None if nothing new can be streamed yet.
        """
        prior_text = current_text[: max(0, len(current_text) - len(delta_text))]

        prior_reasoning, prior_content = self._get_streamable_sections(prior_text)
        reasoning, content = self._get_streamable_sections(current_text)

        # Streamable sections only ever grow as text is appended, so the new
        # pieces are the unseen tails of each section.
        delta_reasoning = reasoning[len(prior_reasoning) :]
        delta_content = content[len(prior_content) :]

        if not delta_reasoning and not delta_content:
            return None
        return DeltaMessage(
            reasoning=delta_reasoning or None,
            content=delta_content or None,
        )

    #### Implementation details of stream parsing for granite models
    def _is_reasoning_start_substr(self, text: str) -> bool:
        """Check if a text matches one of the possible start reasoning seqs.

        Args:
            text (str): Text to check for leading substr.

        Returns:
            bool: True if text is a prefix of (or equal to) any of the
            possible reasoning start seqs.
        """
        return any(
            think_start.startswith(text) for think_start in self.valid_think_starts
        )

    def _is_response_start_substr(self, text: str) -> bool:
        """Check if a text matches one of the possible start response seqs.

        Args:
            text (str): Text to check for leading substr.

        Returns:
            bool: True if text is a prefix of (or equal to) any of the
            possible response start seqs.
        """
        return any(
            response_start.startswith(text)
            for response_start in self.valid_response_starts
        )

    def _get_partial_response_start_len(self, reasoning: str) -> int:
        """Length of the longest suffix of reasoning that could still turn
        into a start of response sequence (0 if there is none).

        Args:
            reasoning (str): Reasoning text seen so far (no response seq yet).

        Returns:
            int: Number of trailing characters to withhold from streaming.
        """
        max_len = min(len(reasoning), self.longest_response_start)
        for length in range(max_len, 0, -1):
            if self._is_response_start_substr(reasoning[-length:]):
                return length
        return 0

    def _get_streamable_sections(self, text: str) -> tuple[str, str]:
        """Return the (reasoning, content) of text that is safe to stream.

        This depends only on text. Appending characters never shrinks either
        section, so deltas can be computed by diffing two snapshots.

        Args:
            text (str): A full snapshot of the generated text.

        Returns:
            tuple[str, str]: Reasoning and content that may be shown; empty
            strings where there is nothing to show yet.
        """
        reasoning, _, content = self._get_content_sections(text)
        if reasoning is None:
            # No complete start of reasoning seq. While the text could still
            # become one, show nothing; once it cannot, it is all content.
            if self._is_reasoning_start_substr(text):
                return "", ""
            return "", text
        if content is None:
            # Inside the reasoning section: hold back a trailing partial
            # start of response seq (e.g. "Here is my") until it resolves.
            held_back = self._get_partial_response_start_len(reasoning)
            return reasoning[: len(reasoning) - held_back], ""
        return reasoning, content

    def _get_content_sections(
        self, current_text: str
    ) -> tuple[str | None, int | None, str | None]:
        """Parse the text to extract the reasoning content / content
        if we have them.

        The text must begin with a complete start of reasoning seq. The
        reasoning then runs until the earliest complete start of response seq.

        Args:
            current_text (str): The full previous + delta text.

        Returns:
            tuple[str | None, int | None, str | None]: Tuple of len 3
            containing the reasoning content, the length of the response seq
            (if there is one) and the non-reasoning content. Returns
            (None, None, None) if there is no start of reasoning seq, and
            (reasoning, None, None) if there is no start of response seq.
        """
        think_start = next(
            (ts for ts in self.valid_think_starts if current_text.startswith(ts)),
            None,
        )
        if think_start is None:
            return None, None, None
        reasoning_start = len(think_start)

        # Earliest-ending complete start of response seq after the think start.
        best_end, best_len = None, None
        for response_start in self.valid_response_starts:
            idx = current_text.find(response_start, reasoning_start)
            if idx < 0:
                continue
            end = idx + len(response_start)
            if best_end is None or end < best_end:
                best_end, best_len = end, len(response_start)

        if best_end is None:
            return current_text[reasoning_start:], None, None
        reasoning = current_text[reasoning_start : best_end - best_len]
        return reasoning, best_len, current_text[best_end:]