# gptoss_reasoning_parser.py (8.25 KB)
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import json
4
from collections.abc import Iterable, Sequence
5
from typing import TYPE_CHECKING
6
7
8

from transformers import PreTrainedTokenizerBase

9
from vllm.entrypoints.mcp.tool_server import ToolServer
10
from vllm.entrypoints.openai.engine.protocol import DeltaMessage
11
from vllm.entrypoints.openai.parser.harmony_utils import parse_chat_output
12
from vllm.logger import init_logger
13
from vllm.reasoning import ReasoningParser
14

15
16
17
18
if TYPE_CHECKING:
    from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
    from vllm.entrypoints.openai.responses.protocol import ResponsesRequest

19
20
# Module-level logger; prepare_structured_tag uses it to report which builtin
# tools were detected on the tool server.
logger = init_logger(__name__)

21
# Base structural tag used when no builtin tools are available.
# GptOssReasoningParser.prepare_structured_tag serializes it with json.dumps.
# Text starting with the "<|channel|>analysis" trigger is constrained to an
# analysis message ("<|channel|>analysis<|message|>" ... "<|end|>") whose body
# is free text. "stop_after_first": False presumably lets the trigger fire more
# than once per output — NOTE(review): confirm against the structural-tag
# backend's documentation.
# tag_with_builtin_funcs deep-copies this dict before extending it, so the
# module-level value is never mutated.
no_func_reasoning_tag = {
    "type": "structural_tag",
    "format": {
        "type": "triggered_tags",
        "tags": [
            {
                "begin": "<|channel|>analysis<|message|>",
                "content": {"type": "any_text"},
                "end": "<|end|>",
            }
        ],
        "triggers": ["<|channel|>analysis"],
        "stop_after_first": False,
    },
}


def from_builtin_tool_to_tag(tool: str) -> list[dict]:
    """Build the structural-tag entries that address one builtin tool.

    Two entries are produced, in this order: one for a message sent to
    ``tool`` on the commentary channel, and one for the analysis channel.
    Each entry has free-text content and is closed by ``<|end|>``.

    Args:
        tool: Builtin tool name, inserted after ``to=`` in the begin marker.

    Returns:
        A new list of two fresh tag dicts; no objects are shared between calls
        or between the two entries.
    """
    begin_markers = (
        f"<|channel|>commentary to={tool}",
        f"<|channel|>analysis to={tool}",
    )
    return [
        {"begin": marker, "content": {"type": "any_text"}, "end": "<|end|>"}
        for marker in begin_markers
    ]


54
def tag_with_builtin_funcs(no_func_reasoning_tag, builtin_tool_list: list[str]) -> dict:
    """Return a copy of a structural tag extended with builtin-tool tags.

    The input tag is deep-copied first, so the caller's dict (for example the
    module-level ``no_func_reasoning_tag``) is never mutated.

    Args:
        no_func_reasoning_tag: Base structural tag whose ``"format"`` entry
            holds ``"triggers"`` and ``"tags"`` lists.
        builtin_tool_list: Builtin tool names. Each contributes the entries
            produced by ``from_builtin_tool_to_tag``, in list order.

    Returns:
        The extended structural-tag dict.
    """
    import copy

    new_tag = copy.deepcopy(no_func_reasoning_tag)
    tag_format = new_tag["format"]
    # The commentary-channel entries from from_builtin_tool_to_tag need their
    # own trigger. The base "<|channel|>analysis" trigger is already a prefix
    # of the analysis-channel entries.
    tag_format["triggers"].append("<|channel|>commentary to=")
    for tool in builtin_tool_list:
        tag_format["tags"].extend(from_builtin_tool_to_tag(tool))
    return new_tag

64
65
66
67
68
69
70
71
72

class GptOssReasoningParser(ReasoningParser):
    """
    Reasoning parser for GptOss model.

    The GptOss model uses harmony to extract reasoning content and this parser
    is only used for detecting the end of the reasoning content.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs):
        super().__init__(tokenizer, *args, **kwargs)

        # The model can output some special tokens between "final" and "<|message|>"
        # So we need to look for both sequences to determine the end of reasoning.
        self.reasoning_end_token_ids_prefix = self.model_tokenizer.encode(
            "<|channel|>final"
        )
        self.reasoning_end_token_ids_suffix = self.model_tokenizer.encode("<|message|>")
        # We also need to check for the <|end|> token to avoid false positives from
        # previous messages in multi-turn conversations.
        self.eom_token_id = self.vocab["<|end|>"]
        # Maximum number of tokens tolerated between the end of the prefix and
        # the start of the suffix.
        self.reasoning_max_num_between_tokens = 20

    @staticmethod
    def _matches_at(
        input_ids: Sequence[int], start: int, pattern: Sequence[int]
    ) -> bool:
        """Return True if ``pattern`` occurs in ``input_ids`` starting at ``start``.

        The comparison is element by element, so the result does not depend
        on the concrete sequence type. Comparing a slice to a list directly
        fails for tuples, because a tuple never compares equal to a list.
        """
        if start < 0 or start + len(pattern) > len(input_ids):
            return False
        return all(
            input_ids[start + offset] == token
            for offset, token in enumerate(pattern)
        )

    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
        """Return True if the current message has switched to the final channel.

        The end of reasoning is the prefix ``<|channel|>final`` followed by the
        suffix ``<|message|>``. At most ``reasoning_max_num_between_tokens``
        tokens may sit between them. The scan runs backwards from the end of
        ``input_ids`` and stops at the first ``<|end|>`` token it meets, so
        markers belonging to earlier messages are ignored.

        Args:
            input_ids: Token ids to inspect; any indexable sequence of ints.

        Returns:
            True if the end-of-reasoning marker is found in the current message.
        """
        end_token_ids_prefix = self.reasoning_end_token_ids_prefix
        end_token_ids_suffix = self.reasoning_end_token_ids_suffix
        assert len(end_token_ids_prefix) > 0, "reasoning_end_token_ids_prefix is empty"
        assert len(end_token_ids_suffix) > 0, "reasoning_end_token_ids_suffix is empty"

        # Check if the end sequence is present in the input_ids.
        # We search from the end of input_ids to find the last match.
        for i in range(len(input_ids) - len(end_token_ids_prefix), -1, -1):
            if input_ids[i] == self.eom_token_id:
                # We looped backwards far enough to find the end of a previous message,
                # which means we have searched the entirety of the current message
                # and can exit early without searching further back into prior
                # messages of the conversation.
                return False
            if self._matches_at(input_ids, i, end_token_ids_prefix):
                # We have found the prefix, now we look for the suffix after the prefix.
                suffix_start = i + len(end_token_ids_prefix)
                for j in range(
                    suffix_start, len(input_ids) - len(end_token_ids_suffix) + 1
                ):
                    if j - suffix_start >= self.reasoning_max_num_between_tokens:
                        break
                    if self._matches_at(input_ids, j, end_token_ids_suffix):
                        return True

        return False

    def is_reasoning_end_streaming(
        self, input_ids: Sequence[int], delta_ids: Iterable[int]
    ) -> bool:
        """Run ``is_reasoning_end`` over only the tail of ``input_ids``.

        The tail is long enough to hold one full end-of-reasoning pattern plus
        every token in ``delta_ids``. Shorter inputs are checked in full.
        """
        # The pattern window covers the end-of-reasoning marker itself.
        # We add len(delta_ids) so that under speculative decoding (where
        # a single step can accept many tokens) the entire accepted chunk
        # is always inside the scan region.
        delta_ids = tuple(delta_ids)
        pattern_len = (
            len(self.reasoning_end_token_ids_prefix)
            + self.reasoning_max_num_between_tokens
            + len(self.reasoning_end_token_ids_suffix)
        )
        window = pattern_len + len(delta_ids)
        n = len(input_ids)
        if n <= window:
            return self.is_reasoning_end(input_ids)
        return self.is_reasoning_end(input_ids[n - window :])

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """Return the content portion of ``input_ids`` as token ids.

        The content is obtained from ``parse_chat_output`` as text and then
        re-encoded with the model tokenizer. An empty list is returned when no
        content was parsed.
        """
        _, content, _ = parse_chat_output(input_ids)
        if content is None:
            return []
        return self.model_tokenizer.encode(content)

    @staticmethod
    def _incremental_text(previous: str | None, current: str | None) -> str | None:
        """Return the part of ``current`` that was not already in ``previous``.

        Returns None when ``current`` is None or adds nothing new. If
        ``current`` does not extend ``previous``, the whole of ``current`` is
        returned.
        """
        if current is None:
            return None
        base = previous or ""
        if current.startswith(base):
            return current[len(base) :] or None
        return current

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        """Compute the reasoning and content delta for one streaming step.

        Both the previous and the current token ids are parsed with
        ``parse_chat_output``. Each field's delta is the newly appended text.
        Only the token-id arguments are used; the text arguments are ignored.

        Returns:
            A ``DeltaMessage`` with the new reasoning and/or content, or None
            when neither changed.
        """
        prev_reasoning, prev_content, _ = parse_chat_output(list(previous_token_ids))
        cur_reasoning, cur_content, _ = parse_chat_output(list(current_token_ids))

        reasoning_delta = self._incremental_text(prev_reasoning, cur_reasoning)
        content_delta = self._incremental_text(prev_content, cur_content)
        if reasoning_delta is None and content_delta is None:
            return None
        return DeltaMessage(reasoning=reasoning_delta, content=content_delta)

    def extract_reasoning(
        self,
        model_output: str,
        request: "ChatCompletionRequest | ResponsesRequest",
    ) -> tuple[str | None, str | None]:
        """Not supported for gpt-oss; always raises NotImplementedError."""
        raise NotImplementedError(
            "gpt-oss has a special branch for parsing reasoning in non-streaming mode. This method shouldn't be used."  # noqa: E501
        )

    # This function prepares the structural tag to format reasoning output
    def prepare_structured_tag(
        self, original_tag: str | None, tool_server: ToolServer | None
    ) -> str | None:
        """Return the structural tag, as a JSON string, used to format reasoning output.

        Args:
            original_tag: A caller-supplied tag. If given, it is returned
                unchanged.
            tool_server: Server queried for the builtin tools "browser",
                "python" and "container". If None, the base reasoning tag is
                used.

        Returns:
            ``original_tag`` if provided; otherwise the base reasoning tag,
            extended with any builtin tools available on ``tool_server``,
            serialized with ``json.dumps``.
        """
        if original_tag is not None:
            # Merging the reasoning tag into a caller-supplied tag is risky, so
            # the original tag is returned as-is.
            return original_tag
        if tool_server is None:
            return json.dumps(no_func_reasoning_tag)

        builtin_tool_list: list[str] = [
            tool
            for tool in ("browser", "python", "container")
            if tool_server.has_tool(tool)
        ]
        if builtin_tool_list:
            logger.info("Builtin_tool_list: %s", builtin_tool_list)
            return json.dumps(
                tag_with_builtin_funcs(no_func_reasoning_tag, builtin_tool_list)
            )
        logger.info("Builtin_tool_list is empty")
        return json.dumps(no_func_reasoning_tag)