# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
4
5
6
7
from collections.abc import Sequence

from transformers import PreTrainedTokenizerBase

8
from vllm.entrypoints.harmony_utils import parse_chat_output
9
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
10
from vllm.entrypoints.tool_server import ToolServer
11
12
13
14
15
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser, ReasoningParserManager

logger = init_logger(__name__)

16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# Default structural tag used when no built-in tools are available.
#
# It declares a single "triggered" tag: once the trigger text
# "<|channel|>analysis" appears, the span that begins with
# "<|channel|>analysis<|message|>" and ends with "<|end|>" may contain any
# text. "stop_after_first" is False.
# NOTE(review): the exact semantics of "triggered_tags" / "stop_after_first"
# are defined by the structured-output backend that consumes this dict --
# confirm there.
#
# NOTE: the identifier is misspelled ("reaonsing"); it is referenced under
# this exact name elsewhere in this module, so it must not be renamed in
# isolation.
no_func_reaonsing_tag = {
    "type": "structural_tag",
    "format": {
        "type": "triggered_tags",
        "tags": [
            {
                "begin": "<|channel|>analysis<|message|>",
                "content": {"type": "any_text"},
                "end": "<|end|>",
            }
        ],
        "triggers": ["<|channel|>analysis"],
        "stop_after_first": False,
    },
}


def from_builtin_tool_to_tag(tool: str) -> list[dict]:
    """Return the structural-tag entries for a single built-in tool.

    Two entries are produced, in this order: one whose begin marker addresses
    the tool on the ``commentary`` channel and one on the ``analysis``
    channel. Both accept any text and are closed by ``<|end|>``.

    Args:
        tool: Name of the tool, inserted verbatim after ``to=``.

    Returns:
        A new list holding two freshly built tag dicts.
    """
    begin_markers = (
        f"<|channel|>commentary to={tool}",
        f"<|channel|>analysis to={tool}",
    )
    # Each entry gets its own "content" dict so callers may mutate one entry
    # without affecting the other.
    return [
        {
            "begin": marker,
            "content": {"type": "any_text"},
            "end": "<|end|>",
        }
        for marker in begin_markers
    ]


def tag_with_builtin_funcs(no_func_reaonsing_tag, builtin_tool_list: list[str]) -> dict:
    """Return a copy of a structural tag extended with built-in tool tags.

    The input tag is deep-copied, so the caller's object is never modified.
    The copy always gains the ``"<|channel|>commentary to="`` trigger, and for
    every tool in ``builtin_tool_list`` the entries produced by
    ``from_builtin_tool_to_tag`` are appended to its tag list.

    Args:
        no_func_reaonsing_tag: Base structural tag; must contain
            ``["format"]["triggers"]`` and ``["format"]["tags"]`` lists.
        builtin_tool_list: Names of the built-in tools to add, in order.

    Returns:
        The extended copy.
    """
    import copy

    extended = copy.deepcopy(no_func_reaonsing_tag)
    tag_format = extended["format"]

    tag_format["triggers"].append("<|channel|>commentary to=")

    tool_tags = tag_format["tags"]
    for tool_name in builtin_tool_list:
        tool_tags.extend(from_builtin_tool_to_tag(tool_name))

    return extended

59

60
@ReasoningParserManager.register_module("openai_gptoss")
class GptOssReasoningParser(ReasoningParser):
    """
    Reasoning parser for GptOss model.

    The GptOss model uses harmony to extract reasoning content. This parser
    detects the end of the reasoning content, splits streamed output into
    reasoning/content deltas via ``parse_chat_output``, and prepares the
    structural tag used to format reasoning output. Non-streaming extraction
    is intentionally not implemented here (see ``extract_reasoning_content``).
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs):
        """Initialize the parser and pre-encode the end-of-reasoning marker.

        Args:
            tokenizer: Tokenizer forwarded to ``ReasoningParser.__init__``.
            *args: Extra positional arguments forwarded to the base class.
            **kwargs: Extra keyword arguments forwarded to the base class.
        """
        super().__init__(tokenizer, *args, **kwargs)
        # Token ids of the header that opens the assistant's "final" channel;
        # seeing this sequence means the reasoning part is over.
        self.reasoning_end_token_ids = self.model_tokenizer.encode(
            "<|start|>assistant<|channel|>final<|message|>"
        )

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        """Return True if the end-of-reasoning marker occurs in ``input_ids``.

        Args:
            input_ids: Token ids generated so far.

        Returns:
            True when ``reasoning_end_token_ids`` appears as a contiguous
            sub-sequence of ``input_ids``, otherwise False (including when
            ``input_ids`` is shorter than the marker).
        """
        end_token_ids = self.reasoning_end_token_ids
        assert len(end_token_ids) > 0, "reasoning_end_token_ids is empty"
        window = len(end_token_ids)
        # Check if the end sequence is present in the input_ids.
        # We search from the end of input_ids to find the last match.
        for start in range(len(input_ids) - window, -1, -1):
            if input_ids[start : start + window] == end_token_ids:
                return True
        return False

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """Return the token ids of the non-reasoning content in ``input_ids``.

        Args:
            input_ids: Token ids of the model output.

        Returns:
            The content text reported by ``parse_chat_output``, re-encoded
            with the model tokenizer, or an empty list when there is no
            content.
        """
        _, content, _ = parse_chat_output(input_ids)
        if content is None:
            return []
        return self.model_tokenizer.encode(content)

    @staticmethod
    def _text_delta(previous: str | None, current: str | None) -> str | None:
        """Return the part of ``current`` that is new relative to ``previous``.

        Returns None when ``current`` is None or nothing was added. If
        ``current`` does not extend ``previous`` (the parsed text changed
        rather than grew), the whole ``current`` text is returned.
        """
        if current is None:
            return None
        already_sent = previous or ""
        if current.startswith(already_sent):
            # Empty suffix means "no new text"; normalize it to None.
            return current[len(already_sent) :] or None
        return current

    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        """Compute the reasoning/content delta for one streaming step.

        Only the token-id arguments are used: the previous and the current
        token sequences are each parsed with ``parse_chat_output`` and the
        resulting reasoning and content texts are diffed.

        Args:
            previous_text: Unused.
            current_text: Unused.
            delta_text: Unused.
            previous_token_ids: Token ids generated before this step.
            current_token_ids: Token ids generated up to and including this
                step.
            delta_token_ids: Unused.

        Returns:
            A ``DeltaMessage`` carrying the new reasoning and/or content
            text, or None when this step added neither.
        """
        prev_reasoning, prev_content, _ = parse_chat_output(list(previous_token_ids))
        cur_reasoning, cur_content, _ = parse_chat_output(list(current_token_ids))

        reasoning_delta = self._text_delta(prev_reasoning, cur_reasoning)
        content_delta = self._text_delta(prev_content, cur_content)

        if reasoning_delta is None and content_delta is None:
            return None
        return DeltaMessage(reasoning_content=reasoning_delta, content=content_delta)

    def extract_reasoning_content(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> tuple[str | None, str | None]:
        """Not supported for gpt-oss.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "gpt-oss has a special branch for parsing reasoning in non-streaming mode. This method shouldn't be used."  # noqa: E501
        )

    # This function prepares the structural tag to format reasoning output
    def prepare_structured_tag(
        self, original_tag: str | None, tool_server: ToolServer | None
    ) -> str:
        """Return the JSON structural tag used to format reasoning output.

        Args:
            original_tag: A caller-supplied tag. When given, it is returned
                unchanged.
            tool_server: Tool server that is asked which built-in tools
                ("browser", "python", "container") are available, or None.

        Returns:
            ``original_tag`` if it is not None; otherwise the JSON encoding
            of the default tag, extended with tags for every available
            built-in tool.
        """
        if original_tag is not None:
            # There is potential risk for appending the tag to the original tag
            return original_tag

        if tool_server is None:
            return json.dumps(no_func_reaonsing_tag)

        # Probe the built-in tools in a fixed order so the generated tag is
        # deterministic.
        builtin_tool_list: list[str] = [
            tool_name
            for tool_name in ("browser", "python", "container")
            if tool_server.has_tool(tool_name)
        ]

        if builtin_tool_list:
            logger.info("Builtin_tool_list: %s", builtin_tool_list)
            return json.dumps(
                tag_with_builtin_funcs(no_func_reaonsing_tag, builtin_tool_list)
            )

        logger.info("Builtin_tool_list is empty")
        return json.dumps(no_func_reaonsing_tag)