# gptoss_reasoning_parser.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Sequence

from transformers import PreTrainedTokenizerBase

8
from vllm.entrypoints.harmony_utils import parse_chat_output
9
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
10
11
12
13
14
15
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser, ReasoningParserManager

# Module-level logger obtained from vLLM's ``init_logger`` helper.
logger = init_logger(__name__)


@ReasoningParserManager.register_module("openai_gptoss")
class GptOssReasoningParser(ReasoningParser):
    """
    Reasoning parser for GptOss model.

    The GptOss model uses harmony to extract reasoning content and this parser
    is only used for detecting the end of the reasoning content.

    Reasoning is treated as finished once the token ids of the harmony header
    ``<|start|>assistant<|channel|>final<|message|>`` appear in the input ids.
    Splitting text into reasoning and content is delegated to
    ``parse_chat_output``.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs):
        super().__init__(tokenizer, *args, **kwargs)
        # Token ids of the header that opens the assistant's "final" channel.
        # NOTE(review): ``self.model_tokenizer`` is presumably populated by
        # ReasoningParser.__init__ from ``tokenizer`` -- confirm in the base.
        self.reasoning_end_token_ids = self.model_tokenizer.encode(
            "<|start|>assistant<|channel|>final<|message|>"
        )

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        """Return True if ``input_ids`` contains the end-of-reasoning sequence.

        Args:
            input_ids: Token ids to search.

        Returns:
            True when ``self.reasoning_end_token_ids`` occurs as a contiguous
            run anywhere in ``input_ids``; False otherwise (including when
            ``input_ids`` is shorter than the sequence).

        Raises:
            ValueError: If ``self.reasoning_end_token_ids`` is empty.
        """
        end_token_ids = self.reasoning_end_token_ids
        # Explicit check instead of ``assert`` so it is not stripped by
        # ``python -O``: an empty pattern would match the empty slice on the
        # first iteration below and wrongly report that reasoning has ended.
        if not end_token_ids:
            raise ValueError("reasoning_end_token_ids is empty")
        end_len = len(end_token_ids)
        # Slide a window from the end of input_ids backwards so that a match
        # near the tail is found first.
        for start in range(len(input_ids) - end_len, -1, -1):
            if input_ids[start : start + end_len] == end_token_ids:
                return True
        return False

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """Return token ids for the content (non-reasoning) part of the output.

        The ids are produced by re-encoding the content text that
        ``parse_chat_output`` extracts from ``input_ids``; they are not sliced
        out of ``input_ids`` directly.

        Returns:
            The encoded content, or ``[]`` when no content was parsed.
        """
        _, content, _ = parse_chat_output(input_ids)
        if content is None:
            return []
        return self.model_tokenizer.encode(content)

    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        """Compute the reasoning/content delta for one streaming step.

        Both the previous and the current token ids are parsed with
        ``parse_chat_output`` and the newly appended text of each field is
        returned. The text arguments and ``delta_token_ids`` are not used.

        Returns:
            A ``DeltaMessage`` with the new reasoning and/or content text, or
            None when neither field produced new text.
        """
        # NOTE: both sequences are re-parsed in full on every call, so the
        # per-step cost grows with the length of the generated output.
        prev_reasoning, prev_content, _ = parse_chat_output(list(previous_token_ids))
        cur_reasoning, cur_content, _ = parse_chat_output(list(current_token_ids))
        reasoning_delta = self._text_delta(cur_reasoning, prev_reasoning)
        content_delta = self._text_delta(cur_content, prev_content)
        if reasoning_delta is None and content_delta is None:
            return None
        return DeltaMessage(reasoning_content=reasoning_delta, content=content_delta)

    @staticmethod
    def _text_delta(current: str | None, previous: str | None) -> str | None:
        """Return the part of ``current`` that is new relative to ``previous``.

        Returns None when ``current`` is None or when it adds no text beyond
        ``previous``. If ``current`` does not start with ``previous``, the
        whole of ``current`` is returned.
        """
        if current is None:
            return None
        prefix = previous or ""
        if current.startswith(prefix):
            return current[len(prefix) :] or None
        return current

    def extract_reasoning_content(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> tuple[str | None, str | None]:
        """Not supported for gpt-oss; always raises.

        Raises:
            NotImplementedError: Always. As the message states, non-streaming
                gpt-oss output is parsed by a separate dedicated branch.
        """
        raise NotImplementedError(
            "gpt-oss has a special branch for parsing reasoning in non-streaming mode. This method shouldn't be used."  # noqa: E501
        )