olmo3_reasoning_parser.py 10.7 KB
Newer Older
1
2
3
4
5
6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import dataclasses as dt
import enum
from collections.abc import Sequence
7
from typing import TYPE_CHECKING
8
9
10
11

import regex as re

if TYPE_CHECKING:
12
    from vllm.tokenizers import TokenizerLike
13

14
15
16
17
18
from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    DeltaMessage,
    ResponsesRequest,
)
19
from vllm.logger import init_logger
20
from vllm.reasoning import ReasoningParser
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38

# Module-level logger created through vLLM's init_logger, named after this
# module. NOTE(review): not referenced by the code visible in this file.
logger = init_logger(__name__)


class Olmo3ReasoningState(enum.Enum):
    """Which part of the model output the streaming parser is inside."""

    # inside the <think> ... </think> reasoning block
    REASONING = enum.auto()
    # regular (non-reasoning) output
    CONTENT = enum.auto()


@dt.dataclass(frozen=True)
class Indices:
    start: int
    end: int

    def __len__(self):
        return self.end - self.start


39
def string_overlap(a: str, b: str) -> tuple[Indices | None, Indices | None]:
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
    """
    Find the longest overlap where the end of string a matches the start
    of string b.

    Args:
        a: First string
        b: Second string

    Returns:
        Tuple of IndicesTuples representing the overlapping portions in each
        string, or a tuple of None if no overlap exists
    """

    # swap so a is always the shorter string
    a, b, swap = (a, b, False) if len(a) < len(b) else (b, a, True)

    # first check: is a fully contained in b?
    if a in b:
        ind_a = Indices(0, len(a))
        ind_b = Indices(b.index(a), b.index(a) + len(a))
        return (ind_b, ind_a) if swap else (ind_a, ind_b)

    # second check: does the end of a overlap with the
    #               beginning of b?
    for i in range(len(a) - 1, 0, -1):
        if a[-i:] == b[:i]:
            ind_a = Indices(len(a) - i, len(a))
            ind_b = Indices(0, i)
            return (ind_b, ind_a) if swap else (ind_a, ind_b)

    # third check: does the beginning of a overlap with
    #              the end of b?
    for i in range(len(a) - 1, 0, -1):
        if b[-i:] == a[:i]:
            ind_a = Indices(0, i)
            ind_b = Indices(len(b) - i, len(b))
            return (ind_b, ind_a) if swap else (ind_a, ind_b)

    return None, None


@dt.dataclass
class Olmo3ReasoningBuffer:
    """Accumulates streamed text and splits it into reasoning and content.

    Olmo 3 writes ``<think>`` / ``</think>`` as ordinary text, so a marker
    may be spread over several deltas. Incoming text is appended to
    ``buffer`` and only released once the latest delta no longer looks like
    a partial marker.
    """

    # literal strings that open and close the reasoning block
    think_start: str = "<think>"
    think_end: str = "</think>"
    # text received from the stream but not yet emitted
    buffer: str = ""

    # we start in reasoning state to support cases where we hardcode
    # <think> as the start of the reasoning block.
    # In those cases, the only token we will see is </think>, which
    # is when we switch to content state.
    state: Olmo3ReasoningState = Olmo3ReasoningState.REASONING

    def process_buffer(self) -> DeltaMessage | None:
        """Consume buffered text and return the next delta to emit.

        Markers are handled in the order they occur in the buffer. If text
        precedes the first marker, only that text is returned (as
        ``content`` before ``<think>``, as ``reasoning`` before
        ``</think>``) and the text after the marker stays in ``buffer`` for
        a later call. A marker with nothing in front of it only switches
        ``state``; the rest of the buffer is then handled right away. When
        no marker is left, the whole buffer is flushed as reasoning or
        content according to ``state``.

        Returns:
            A DeltaMessage with ``reasoning`` or ``content`` set (possibly
            to an empty string); None only if ``state`` is neither known
            value.
        """
        start_think_idx = self.buffer.find(self.think_start)
        # first </think>, consistent with the non-streaming regex, which
        # splits reasoning from content at the first closing marker
        end_think_idx = self.buffer.find(self.think_end)

        # only handle <think> first when no </think> precedes it; otherwise
        # the reasoning text and the literal </think> in front of a later
        # <think> would be emitted as content
        if start_think_idx >= 0 and (
            end_think_idx < 0 or start_think_idx < end_think_idx
        ):
            self.state = Olmo3ReasoningState.REASONING
            pretext, self.buffer = (
                self.buffer[:start_think_idx],
                self.buffer[start_think_idx + len(self.think_start) :],
            )
            if start_think_idx > 0:
                # this covers the case there's content before
                # the start of the reasoning block
                return DeltaMessage(content=pretext)
            # the buffer was shortened; locate </think> in what is left
            end_think_idx = self.buffer.find(self.think_end)

        if end_think_idx >= 0:
            self.state = Olmo3ReasoningState.CONTENT
            pretext, self.buffer = (
                self.buffer[:end_think_idx],
                self.buffer[end_think_idx + len(self.think_end) :],
            )
            if end_think_idx > 0:
                # this covers the case there's content before
                # the end of the reasoning block
                return DeltaMessage(reasoning=pretext)
            # nothing preceded </think>: handle the remainder now, which may
            # itself contain a <think> that must not leak into content
            return self.process_buffer()

        if self.state == Olmo3ReasoningState.REASONING:
            # we are inside reasoning block, return and empty
            # the text buffer
            text_buffer, self.buffer = self.buffer, ""
            return DeltaMessage(reasoning=text_buffer)

        if self.state == Olmo3ReasoningState.CONTENT:
            # we are outside reasoning block, return and empty
            # the text buffer
            text_buffer, self.buffer = self.buffer, ""
            return DeltaMessage(content=text_buffer)

        # nothing to return unless we are in reasoning or content state
        return None

    def __len__(self):
        # is the length of the text buffer
        return len(self.buffer)

    def add_text(self, delta_text: str) -> DeltaMessage | None:
        """Append a streamed delta and return whatever can be emitted now.

        Args:
            delta_text: The newly generated text.

        Returns:
            The DeltaMessage produced by ``process_buffer``, or None while
            the delta only partially overlaps a marker (the text is kept in
            ``buffer`` until the marker is complete or ruled out).
        """
        # we start by adding the delta text to the buffer
        self.buffer += delta_text

        # setting this to empty before starting
        delta_message: DeltaMessage | None = None

        # we start by computing the overlap between the delta_text
        # and start/end of think tokens.
        _, overlap_think_start = string_overlap(delta_text, self.think_start)
        _, overlap_think_end = string_overlap(delta_text, self.think_end)

        # "partial" means the delta covers only part of a marker, so the
        # marker may still be completed by upcoming deltas
        partial_overlap_start = overlap_think_start is not None and len(
            overlap_think_start
        ) < len(self.think_start)
        partial_overlap_end = overlap_think_end is not None and len(
            overlap_think_end
        ) < len(self.think_end)

        if (
            partial_overlap_start
            and self.think_start in self.buffer
            and not partial_overlap_end
        ):
            # we can only process the buffer if partial overlap
            # is the last part of think token (thus causing
            # text_buffer to contain the start of think token)
            # and there are no partial overlaps with end think
            delta_message = self.process_buffer()

        elif partial_overlap_end and self.think_end in self.buffer:
            # same as before (partial overlap only allowed)
            # if the buffer contains the end think token,
            # but we don't have to check for partial overlap
            # with start think token because they are handled
            # by the previous condition
            delta_message = self.process_buffer()

        elif partial_overlap_start or partial_overlap_end:
            # in general, if there are overlaps, we don't
            # process the buffer because we want to wait until
            # the think token is fully completed.
            return None
        else:
            # we process the buffer as normal
            delta_message = self.process_buffer()

        return delta_message


class Olmo3ReasoningParser(ReasoningParser):
    """
    Reasoning parser for Olmo 3 model

    Olmo3ReasoningParser

    This class implements a reasoning parser specifically designed for the
    Olmo 3 family of models. Olmo 3 models do not use special tokens to
    indicate reasoning; rather, reasoning trace is wrapped in `<think>` and
    `</think>`, which are tokenized using standard vocabulary entries.
    Because of this, the parser operates in string space, accumulating the
    characters in a buffer until it sees the `<think>` or `</think>` strings
    that switch modes.

    Key Features:
        - For non-stream output, recognizes and extracts reasoning (text
          bracketed by `<think>` and `</think>`) and content (everything
          after the first `</think>`).
        - For stream processing, it uses a buffer to accumulate delta text,
          and outputs progressive delta messages as soon as thinking starts
          or ends.
        - For reliability, some Olmo 3 models may hardcode the first
          `<think>` token in the input text (similar to Deepseek R1,
          or reasoning-only Qwen models). To support such variants, the
          parser can optionally work in cases where the first `<think>`
          token is missing from generation.
    """

    def __init__(self, tokenizer: "TokenizerLike", *args, **kwargs):
        super().__init__(tokenizer, *args, **kwargs)

        self.think_start = r"<think>"
        self.think_end = r"</think>"

        # The markers are escaped so they are always matched literally,
        # even if they are ever changed to strings containing regex
        # metacharacters (a no-op for the current values).
        start_pattern = re.escape(self.think_start)
        end_pattern = re.escape(self.think_end)

        # notice that the first think is optional; this allows template to
        # work in cases when we hardcode a <think> at the beginning of the
        # reasoning template.
        reasoning_expr = (
            rf"^(?:{start_pattern})?(?P<reasoning>.*?)"
            + rf"{end_pattern}(?P<content>.*)$"
        )
        self.reasoning_regex = re.compile(reasoning_expr, re.DOTALL)

        # streaming state: text not yet emitted plus reasoning/content mode
        self.buffer = Olmo3ReasoningBuffer(
            think_start=self.think_start, think_end=self.think_end
        )

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        """Return True if the decoded ``input_ids`` contain ``</think>``.

        The whole id list is decoded on every call.
        """
        text = self.model_tokenizer.decode(input_ids)
        return self.think_end in text

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """Return no content ids.

        In streaming mode the text is already routed by
        ``extract_reasoning_streaming``; the token that completes the end of
        reasoning is also passed to ``is_reasoning_end`` and this method, and
        it is not part of the content, so an empty list is returned.
        """
        return []

    def extract_reasoning(
        self,
        model_output: str,
        request: ChatCompletionRequest | ResponsesRequest,
    ) -> tuple[str | None, str | None]:
        """Extract the reasoning content & content sections, respectively.
        If the sequence doesn't match what we expect, i.e., the model generates
        something else, all content is considered non-reasoning content.

        Args:
            model_output (str): Output of the model to be parsed.
            request (ChatCompletionRequest | ResponsesRequest): Request being
                processed.

        Returns:
            tuple[Optional[str], Optional[str]]: Tuple pair containing the
            reasoning content and non-reasoning content. Empty sections are
            returned as None.
        """

        re_match = self.reasoning_regex.match(model_output)
        if re_match:
            reasoning = re_match.group("reasoning") or None
            content = re_match.group("content") or None
            return reasoning, content

        # no `</think>` found: everything is treated as content
        return None, model_output

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        """Route ``delta_text`` into reasoning or content.

        Only ``delta_text`` is used; it is fed through the string buffer,
        which holds back text that may be part of an incomplete
        ``<think>``/``</think>`` marker. Returns None when nothing can be
        emitted yet.
        """

        delta_message = self.buffer.add_text(delta_text)
        if delta_message is None and self.buffer.think_end in self.buffer.buffer:
            # this is a bit hacky, but, because of how the buffer is
            # constructed, if the last delta_text contains characters that
            # marks the end of thinking tokens, then messages in the buffer
            # would never be processed because we get no other turn. To get
            # around that, we check if the text buffer contains the end of
            # thinking tokens, and, if so, we reprocess the buffer again.
            delta_message = self.buffer.process_buffer()

        return delta_message