# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Sequence
5
from typing import TYPE_CHECKING
6
7
8

from transformers import PreTrainedTokenizerBase

9
from vllm.entrypoints.openai.engine.protocol import DeltaMessage
10
11
12
from vllm.logger import init_logger
from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser

13
14
15
16
if TYPE_CHECKING:
    from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
    from vllm.entrypoints.openai.responses.protocol import ResponsesRequest

17
18
19
20
21
22
logger = init_logger(__name__)


class Ernie45ReasoningParser(BaseThinkingReasoningParser):
    """
    Reasoning parser for the Ernie45 thinking model.

    The Ernie45 thinking model output format is
        abc\n</think>\n\n<response>\ndef\n</response>\n
    or  abc\n</think>\ndef

    - 'abc' goes to reasoning
    - 'def' goes to content
    """

    # Text markers that may wrap the final answer after `</think>`.
    response_start_token: str = "<response>"
    response_end_token: str = "</response>"
    # Vocabulary entry looked up to obtain the newline token ID.
    newline_token: str = "<0x0A>"

    @property
    def start_token(self) -> str:
        """The token that starts reasoning content."""
        return "<think>"

    @property
    def end_token(self) -> str:
        """The token that ends reasoning content."""
        return "</think>"

    def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs):
        """
        Resolve the token IDs of the Ernie45-specific markers.

        Raises:
            ValueError: if no model tokenizer is available after the base
                class has been initialised.
        """
        super().__init__(tokenizer, *args, **kwargs)

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ReasoningParser "
                "constructor during construction."
            )

        # `.get` yields None for markers that are missing from the vocabulary;
        # comparisons against real integer token IDs then simply never match.
        self.response_start_token_id = self.vocab.get(self.response_start_token)
        self.response_end_token_id = self.vocab.get(self.response_end_token)
        self.newline_token_id = self.vocab.get(self.newline_token)

        # Closing markers: leading newlines in the delta that directly
        # follows one of these are dropped from the content.
        self.parser_token_ids = [self.end_token_id, self.response_end_token_id]

    def _strip_response_wrapper(self, text: str) -> str:
        """Remove everything up to `<response>` and from the last `</response>`.

        The end marker is searched only after the start marker has been cut
        off, so its index always refers to the string it is applied to.
        A marker that is not present in `text` leaves that side untouched.
        """
        start_idx = text.find(self.response_start_token)
        if start_idx != -1:
            text = text[start_idx + len(self.response_start_token) :]
        end_idx = text.rfind(self.response_end_token)
        if end_idx != -1:
            text = text[:end_idx]
        return text

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        """
        Extract reasoning content from a delta message.
        Handles streaming output where previous + delta = current.
        Token IDs are used to detect the special markers.

        The Ernie45 thinking model output format is
            abc\n</think>\n\n<response>\ndef\n</response>\n
        or  abc\n</think>\ndef

        - 'abc' goes to reasoning
        - 'def' goes to content

        Returns:
            None when the delta is a single special marker token, otherwise
            a DeltaMessage carrying reasoning and/or content for this delta.
        """
        # Skip single special tokens
        if len(delta_token_ids) == 1 and delta_token_ids[0] in (
            self.start_token_id,
            self.end_token_id,
            self.response_start_token_id,
            self.response_end_token_id,
        ):
            return None

        # No <think> check is needed: the model may have generated </think>
        # without <think>, so only the position of </think> matters.
        if self.end_token_id in delta_token_ids:
            # </think> in delta with more tokens: split into reasoning and
            # content. `partition` keeps the whole delta as reasoning if the
            # marker text is unexpectedly missing instead of slicing at -1.
            reasoning, _, content = delta_text.partition(self.end_token)
            content = self._strip_response_wrapper(content.lstrip("\n"))
            return DeltaMessage(
                reasoning=reasoning,
                content=content if content else None,
            )

        if self.end_token_id in previous_token_ids:
            # </think> in previous, thinking content ends
            content = delta_text
            if self.response_start_token_id in delta_token_ids:
                # drop <response> (and </response> if it is in the same delta)
                content = self._strip_response_wrapper(content.lstrip("\n"))
            elif self.response_end_token_id in delta_token_ids:
                response_end_idx = content.rfind(self.response_end_token)
                if response_end_idx != -1:
                    content = content[:response_end_idx]

            # Remove \n directly after </think> or </response>, and the
            # second \n of a "</think>\n\n" sequence.
            starts_with_newline = (
                len(delta_token_ids) > 0 and delta_token_ids[0] == self.newline_token_id
            )
            follows_closing_marker = previous_token_ids[-1] in self.parser_token_ids
            follows_think_newline = (
                len(previous_token_ids) > 1
                and previous_token_ids[-2] == self.end_token_id
            )
            if starts_with_newline and (
                follows_closing_marker or follows_think_newline
            ):
                content = content.lstrip("\n")

            return DeltaMessage(content=content if content else None)

        # no </think> in previous or delta, reasoning content continues
        return DeltaMessage(reasoning=delta_text)

    def extract_reasoning(
        self, model_output: str, request: "ChatCompletionRequest | ResponsesRequest"
    ) -> tuple[str | None, str | None]:
        """
        Extract reasoning content from the model output.

        The Ernie45 thinking model output format is
            abc\n</think>\n\n\n<response>\ndef\n</response>\n
        or  abc\n</think>\ndef

        - 'abc' goes to reasoning
        - 'def' goes to content

        Returns:
            tuple[Optional[str], Optional[str]]: reasoning content and content
        """
        reasoning, content = super().extract_reasoning(model_output, request)
        if content:
            start_idx = content.find(self.response_start_token)
            end_idx = content.rfind(self.response_end_token)
            # Unwrap only when both markers exist and are in the correct order
            if start_idx != -1 and end_idx != -1 and start_idx < end_idx:
                content = content[start_idx + len(self.response_start_token) : end_idx]
        final_content = content or None

        return reasoning, final_content