basic_parsers.py 7.46 KB
Newer Older
1
2
3
4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from abc import abstractmethod
5
6
from collections.abc import Iterable, Sequence
from itertools import islice
7
from typing import TYPE_CHECKING
8

9
from vllm.entrypoints.openai.engine.protocol import DeltaMessage
10
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
11
from vllm.tokenizers import TokenizerLike
12

13
if TYPE_CHECKING:
14
15
    from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
    from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
16

17
18
19
20

class BaseThinkingReasoningParser(ReasoningParser):
    """
    Base class for reasoning parsers that use thinking tokens.

    This class provides common functionality for parsers that use start and end
    tokens to delimit reasoning content
    (e.g., <think>...</think>, <seed:think>...</seed:think>).

    Subclasses must implement the start and end tokens via abstract
    properties.
    """

    @property
    @abstractmethod
    def start_token(self) -> str:
        """The token that starts reasoning content."""
        raise NotImplementedError

    @property
    @abstractmethod
    def end_token(self) -> str:
        """The token that ends reasoning content."""
        raise NotImplementedError

42
    def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
        """
        Initialize the parser and resolve the thinking-token IDs.

        All arguments are forwarded to the parent constructor. Afterwards
        ``start_token`` and ``end_token`` are looked up in ``self.vocab`` and
        the results are stored as ``start_token_id`` and ``end_token_id``.

        Raises:
            ValueError: If ``self.model_tokenizer`` is falsy once the parent
                constructor has run, or if either token string is falsy.
            RuntimeError: If either token is absent from ``self.vocab``.
        """
        super().__init__(tokenizer, *args, **kwargs)

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ReasoningParser "
                "constructor during construction."
            )

        # Both marker strings have to be non-empty before they can be resolved.
        if not (self.start_token and self.end_token):
            raise ValueError("start_token and end_token must be defined in subclasses")

        think_open_id = self.vocab.get(self.start_token)
        think_close_id = self.vocab.get(self.end_token)
        if think_open_id is None or think_close_id is None:
            raise RuntimeError(
                f"{self.__class__.__name__} reasoning parser could not locate "
                "think start/end tokens in the tokenizer!"
            )

        self.start_token_id: int = think_open_id
        self.end_token_id: int = think_close_id
63

64
    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
65
        start_token_id = self.start_token_id
66
        end_token_id = self.end_token_id
67
68
69
70
71
72
73

        for i in range(len(input_ids) - 1, -1, -1):
            if input_ids[i] == start_token_id:
                return False
            if input_ids[i] == end_token_id:
                return True
        return False
74

75
    def is_reasoning_end_streaming(
76
        self, input_ids: Sequence[int], delta_ids: Iterable[int]
77
78
79
80
    ) -> bool:
        end_token_id = self.end_token_id
        return end_token_id in delta_ids

81
82
83
84
    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """
        Extract the content after the end tokens
        """
85
        if self.end_token_id not in islice(input_ids, 0, max(0, len(input_ids) - 1)):
86
87
            return []
        else:
88
            return input_ids[input_ids.index(self.end_token_id) + 1 :]
89

90
    def extract_reasoning_streaming(
91
92
93
94
95
96
97
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
98
    ) -> DeltaMessage | None:
99
100
101
102
103
104
        """
        Extract reasoning content from a delta message.
        Handles streaming output where previous + delta = current.
        Uses token IDs for faster processing.
        """
        # Skip single special tokens
105
106
107
        if len(delta_token_ids) == 1 and (
            delta_token_ids[0] in [self.start_token_id, self.end_token_id]
        ):
108
109
110
111
112
113
114
115
116
            return None

        # Check if start token is present in previous or delta.
        # Keep compatibility with models that don't generate start tokens.
        if self.start_token_id in previous_token_ids:
            if self.end_token_id in delta_token_ids:
                # start token in previous, end token in delta,
                # extract reasoning content
                end_index = delta_text.find(self.end_token)
117
                reasoning = delta_text[:end_index]
118
                content = delta_text[end_index + len(self.end_token) :]
119
                return DeltaMessage(
120
                    reasoning=reasoning, content=content if content else None
121
122
123
124
125
126
127
128
                )
            elif self.end_token_id in previous_token_ids:
                # start token in previous, end token in previous,
                # reasoning content continues
                return DeltaMessage(content=delta_text)
            else:
                # start token in previous, no end token in previous or delta,
                # reasoning content continues
129
                return DeltaMessage(reasoning=delta_text)
130
131
132
133
134
135
        elif self.start_token_id in delta_token_ids:
            if self.end_token_id in delta_token_ids:
                # start token in delta, end token in delta,
                # extract reasoning content
                start_index = delta_text.find(self.start_token)
                end_index = delta_text.find(self.end_token)
136
                reasoning = delta_text[start_index + len(self.start_token) : end_index]
137
                content = delta_text[end_index + len(self.end_token) :]
138
                return DeltaMessage(
139
                    reasoning=reasoning, content=content if content else None
140
141
142
143
                )
            else:
                # start token in delta, no end token in delta,
                # reasoning content continues
144
                return DeltaMessage(reasoning=delta_text)
145
146
147
148
        else:
            # not find thinking start token
            return DeltaMessage(content=delta_text)

149
    def extract_reasoning(
150
        self, model_output: str, request: "ChatCompletionRequest | ResponsesRequest"
151
    ) -> tuple[str | None, str | None]:
152
153
        """
        Extract reasoning content from the model output.
154

155
156
157
158
159
160
        This is the base implementation that works for most models.
        Subclasses can override this method for specific behavior.
        """
        # Check if the start token is present in the model output, remove it
        # if it is present.
        model_output_parts = model_output.partition(self.start_token)
161
162
163
        model_output = (
            model_output_parts[2] if model_output_parts[1] else model_output_parts[0]
        )
164
165
166
167
168
169

        # For models that may not generate start token,
        # assume the reasoning content is always at the start.
        if self.end_token not in model_output:
            return model_output, None
        else:
170
            reasoning, _, content = model_output.partition(self.end_token)
171
172
            # If generation stops right after end-of-think, return null content
            final_content = content or None
173
            return reasoning, final_content
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193

    def count_reasoning_tokens(self, token_ids: Sequence[int]) -> int:
        """Count tokens that fall within start/end thinking markers.

        Uses a depth counter so nested spans are handled safely and stray end
        tokens do not drive the counter negative.
        """
        count = 0
        depth = 0
        for token_id in token_ids:
            if token_id == self.start_token_id:
                depth += 1
                continue
            if token_id == self.end_token_id:
                if depth > 0:
                    depth -= 1
                continue
            if depth > 0:
                count += 1
        return count