glm4_moe_reasoning_parser.py 6.96 KB
Newer Older
Yuxuan Zhang's avatar
Yuxuan Zhang committed
1
2
3
4
5
6
7
8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Sequence
from typing import Optional, Union

from transformers import PreTrainedTokenizerBase

9
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
Yuxuan Zhang's avatar
Yuxuan Zhang committed
10
11
12
13
14
15
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser, ReasoningParserManager

logger = init_logger(__name__)


16
@ReasoningParserManager.register_module("glm45")
class Glm4MoeModelReasoningParser(ReasoningParser):
    """
    Reasoning parser for the Glm4MoeModel model.

    The Glm4MoeModel model uses <think>...</think> tokens to denote reasoning
    text within its output. The model provides a strict switch to disable
    reasoning output via the 'enable_thinking=False' parameter. This parser
    extracts the reasoning content enclosed by <think> and </think> tokens
    from the model's output.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs):
        """
        Resolve the token ids of the markers this parser relies on.

        Args:
            tokenizer: Tokenizer of the served model, forwarded to
                ``ReasoningParser.__init__``. NOTE(review): ``self.model_tokenizer``
                and ``self.vocab`` are read below but not set here; presumably
                the base class populates them from ``tokenizer`` — confirm.

        Raises:
            ValueError: If ``self.model_tokenizer`` is falsy.
            RuntimeError: If ``<think>``, ``</think>`` or ``<|assistant|>``
                cannot be found in ``self.vocab``.
        """
        super().__init__(tokenizer, *args, **kwargs)

        self.think_start_token = "<think>"
        self.think_end_token = "</think>"
        self.assistant_token = "<|assistant|>"

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ReasoningParser "
                "constructor during construction."
            )

        self.think_start_token_id = self.vocab.get(self.think_start_token)
        self.think_end_token_id = self.vocab.get(self.think_end_token)
        self.assistant_token_id = self.vocab.get(self.assistant_token)
        if (
            self.think_start_token_id is None
            or self.think_end_token_id is None
            or self.assistant_token_id is None
        ):
            raise RuntimeError(
                "Glm4MoeModel reasoning parser could not locate "
                "think start/end or assistant tokens in the tokenizer!"
            )

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        """
        Return True if the reasoning section has been closed.

        GLM's chat template has <think></think> tokens after every
        <|assistant|> token. Thus, we need to check if </think> is
        after the most recent <|assistant|> token (if present).
        """
        # Scan backwards: whichever of </think> / <|assistant|> is met first
        # decides whether the current turn's reasoning has ended.
        for token_id in reversed(input_ids):
            if token_id == self.think_end_token_id:
                return True
            if token_id == self.assistant_token_id:
                return False
        return False

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """
        Return the token ids that follow the first </think> token.

        Returns an empty list when </think> is absent (or is the very last
        token, since nothing follows it).
        """
        try:
            end_index = input_ids.index(self.think_end_token_id)
        except ValueError:
            return []
        return input_ids[end_index + 1 :]

    def _split_at_think_end(self, text: str) -> tuple[str, str]:
        """
        Split ``text`` at the first ``</think>`` into (reasoning, content).

        If the literal ``</think>`` text is absent even though its token id
        was seen (presumably because detokenization did not render it), the
        whole text is treated as reasoning instead of slicing with a -1 index.
        """
        reasoning, _, content = text.partition(self.think_end_token)
        return reasoning, content

    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> Union[DeltaMessage, None]:
        """
        Extract reasoning content from a delta message.
        Handles streaming output where previous + delta = current.
        Uses token IDs for faster processing.
        For text <think>abc</think>xyz:
        - 'abc' goes to reasoning_content
        - 'xyz' goes to content

        Returns None when the delta is exactly one <think> or </think> token,
        since those markers carry no user-visible text.
        """
        # Skip single special tokens
        if len(delta_token_ids) == 1 and delta_token_ids[0] in (
            self.think_start_token_id,
            self.think_end_token_id,
        ):
            return None

        if self.think_start_token_id in previous_token_ids:
            if self.think_end_token_id in delta_token_ids:
                # <think> in previous, </think> in delta:
                # split the delta into reasoning and content.
                reasoning_content, content = self._split_at_think_end(delta_text)
                return DeltaMessage(
                    reasoning_content=reasoning_content,
                    content=content if content else None,
                )
            if self.think_end_token_id in previous_token_ids:
                # <think> and </think> both in previous:
                # reasoning is over, the delta is regular content.
                return DeltaMessage(content=delta_text)
            # <think> in previous, no </think> in previous or delta:
            # reasoning content continues.
            return DeltaMessage(reasoning_content=delta_text)

        if self.think_start_token_id in delta_token_ids:
            if self.think_end_token_id in delta_token_ids:
                # <think> and </think> both in delta: drop text up to and
                # including <think>, then split at </think>. If the <think>
                # text itself is missing, split the whole delta.
                before, sep, after = delta_text.partition(self.think_start_token)
                body = after if sep else before
                reasoning_content, content = self._split_at_think_end(body)
                return DeltaMessage(
                    reasoning_content=reasoning_content,
                    content=content if content else None,
                )
            # <think> in delta, no </think> in delta:
            # reasoning content continues.
            return DeltaMessage(reasoning_content=delta_text)

        # No <think> seen at all: thinking is disabled, just content.
        return DeltaMessage(content=delta_text)

    def extract_reasoning_content(
        self, model_output: str, request: ChatCompletionRequest
    ) -> tuple[Optional[str], Optional[str]]:
        """
        Extract reasoning content from the model output.

        For text <think>abc</think>xyz:
        - 'abc' goes to reasoning_content
        - 'xyz' goes to content

        Text before the first <think> is discarded. If the output does not
        contain a <think> followed later by a </think>, it is returned
        unchanged as content.

        Returns:
            tuple[Optional[str], Optional[str]]: reasoning content and content
        """
        # Both markers are required for a reasoning section.
        if (
            self.think_start_token not in model_output
            or self.think_end_token not in model_output
        ):
            return None, model_output

        # <think> is known to be present: keep only what follows it.
        _, _, after_start = model_output.partition(self.think_start_token)

        # </think> may occur only *before* <think> (e.g. "</think>a<think>b").
        # There is then no closed reasoning section, so return the original
        # output as is rather than the truncated remainder.
        if self.think_end_token not in after_start:
            return None, model_output

        reasoning_content, _, content = after_start.partition(self.think_end_token)
        return reasoning_content, content or None