basic_parsers.py 6.76 KB
Newer Older
1
2
3
4
5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from abc import abstractmethod
from collections.abc import Sequence
6
from typing import TYPE_CHECKING, Any
7

8
from vllm.entrypoints.openai.engine.protocol import DeltaMessage
9
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
10
from vllm.tokenizers import TokenizerLike
11

12
if TYPE_CHECKING:
    from vllm.entrypoints.openai.chat_completion.protocol import (
        ChatCompletionRequest,
    )
    from vllm.entrypoints.openai.responses.protocol import (
        ResponsesRequest,
    )
else:
    # In this module the request types appear only in annotations, so at
    # runtime they are aliased to ``Any`` instead of importing the protocol
    # modules (presumably to avoid an import cycle — confirm).
    ChatCompletionRequest = Any
    ResponsesRequest = Any


class BaseThinkingReasoningParser(ReasoningParser):
    """
    Base class for reasoning parsers that use thinking tokens.

    This class provides common functionality for parsers that use start and
    end tokens to delimit reasoning content (e.g., ``<think>...</think>`` or
    ``<seed:think>...</seed:think>``).

    Subclasses must implement the start and end tokens via abstract
    properties.
    """

    @property
    @abstractmethod
    def start_token(self) -> str:
        """The token that starts reasoning content."""
        raise NotImplementedError

    @property
    @abstractmethod
    def end_token(self) -> str:
        """The token that ends reasoning content."""
        raise NotImplementedError

    def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
        """
        Resolve the start/end token strings to vocabulary ids.

        Args:
            tokenizer: The model tokenizer, forwarded to ``ReasoningParser``.

        Raises:
            ValueError: If ``self.model_tokenizer`` is falsy after the base
                constructor runs, or if a subclass defines an empty start or
                end token.
            RuntimeError: If either token is not found in ``self.vocab``.
        """
        super().__init__(tokenizer, *args, **kwargs)

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ReasoningParser "
                "constructor during construction."
            )

        if not self.start_token or not self.end_token:
            raise ValueError("start_token and end_token must be defined in subclasses")

        # ``self.vocab`` comes from the ReasoningParser base class; ``.get``
        # yields None for tokens it does not contain.
        self.start_token_id = self.vocab.get(self.start_token)
        self.end_token_id = self.vocab.get(self.end_token)
        if self.start_token_id is None or self.end_token_id is None:
            raise RuntimeError(
                f"{self.__class__.__name__} reasoning parser could not locate "
                "think start/end tokens in the tokenizer!"
            )

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        """
        Return whether the reasoning section has been closed.

        The ids are scanned from the end, so the most recent start/end token
        decides: an end token means reasoning is over, while a later start
        token (or no marker at all) means it is still open.
        """
        # Bind to locals once; this loop may walk the whole sequence.
        start_token_id = self.start_token_id
        end_token_id = self.end_token_id

        for token_id in reversed(input_ids):
            if token_id == start_token_id:
                return False
            if token_id == end_token_id:
                return True
        return False

    def is_reasoning_end_streaming(
        self, input_ids: list[int], delta_ids: list[int]
    ) -> bool:
        """Return True when the end token appears in the newly generated ids."""
        return self.end_token_id in delta_ids

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """
        Return the token ids that follow the first end token.

        An empty list is returned when the end token is absent, or when it
        is the last id (nothing follows it).
        """
        try:
            end_index = input_ids.index(self.end_token_id)
        except ValueError:
            return []
        # Slicing past the last element yields [], which covers the
        # "end token is last" case without a separate check.
        return input_ids[end_index + 1 :]

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        """
        Extract reasoning content from a delta message.

        Handles streaming output where previous + delta = current. Token ids
        decide which phase the delta belongs to; the text is split only when
        a start/end token falls inside the delta.

        Returns:
            ``None`` when the delta is exactly one start or end token,
            otherwise a ``DeltaMessage`` carrying ``reasoning`` and/or
            ``content``.
        """
        # Skip a delta that is just a lone start/end marker.
        if len(delta_token_ids) == 1 and (
            delta_token_ids[0] in (self.start_token_id, self.end_token_id)
        ):
            return None

        # NOTE(review): when no start token has been seen, the delta is
        # emitted as content (final return below), whereas
        # ``extract_reasoning`` treats text before the end token as
        # reasoning. Subclasses for models that omit the start token may
        # need to override this method.
        if self.start_token_id in previous_token_ids:
            if self.end_token_id in delta_token_ids:
                # Reasoning ends inside this delta: split at the end token.
                # NOTE(review): assumes the end token's text is present in
                # delta_text (not stripped during detokenization); if find()
                # returns -1 the slices below are wrong — confirm.
                end_index = delta_text.find(self.end_token)
                reasoning = delta_text[:end_index]
                content = delta_text[end_index + len(self.end_token) :]
                return DeltaMessage(
                    reasoning=reasoning, content=content if content else None
                )
            if self.end_token_id in previous_token_ids:
                # Reasoning already finished earlier: the delta is content.
                return DeltaMessage(content=delta_text)
            # Still inside the reasoning section.
            return DeltaMessage(reasoning=delta_text)

        if self.start_token_id in delta_token_ids:
            if self.end_token_id in delta_token_ids:
                # Both markers in this delta: reasoning is the text between
                # them, content is whatever follows the end token.
                start_index = delta_text.find(self.start_token)
                end_index = delta_text.find(self.end_token)
                reasoning = delta_text[start_index + len(self.start_token) : end_index]
                content = delta_text[end_index + len(self.end_token) :]
                return DeltaMessage(
                    reasoning=reasoning, content=content if content else None
                )
            # Reasoning starts in this delta and has not ended yet.
            return DeltaMessage(reasoning=delta_text)

        # No start token seen so far: pass the delta through as content.
        return DeltaMessage(content=delta_text)

    def extract_reasoning(
        self, model_output: str, request: ChatCompletionRequest | ResponsesRequest
    ) -> tuple[str | None, str | None]:
        """
        Extract reasoning content from the model output.

        This is the base implementation that works for most models.
        Subclasses can override this method for specific behavior.

        Args:
            model_output: The complete generated text.
            request: The originating request (not used by this base
                implementation).

        Returns:
            A ``(reasoning, content)`` tuple. ``content`` is ``None`` when
            there is no end token or nothing follows it.
        """
        # If a start token is present, drop it and everything before it.
        _, start_found, after_start = model_output.partition(self.start_token)
        if start_found:
            model_output = after_start

        # For models that may not generate start token,
        # assume the reasoning content is always at the start.
        if self.end_token not in model_output:
            return model_output, None

        reasoning, _, content = model_output.partition(self.end_token)
        # If generation stops right after end-of-think, return null content.
        return reasoning, content or None