# deepseek_r1_reasoning_parser.py
# SPDX-License-Identifier: Apache-2.0

3
4
from collections.abc import Sequence
from typing import Optional, Union
5
6
7
8
9
10

from transformers import PreTrainedTokenizerBase

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              DeltaMessage)
from vllm.logger import init_logger
11
from vllm.reasoning import ReasoningParser, ReasoningParserManager
12
13
14
15
16
17
18
19
20

# Module-level logger named after this module (not referenced by the parser
# code in this file as shown).
logger = init_logger(__name__)


@ReasoningParserManager.register_module("deepseek_r1")
class DeepSeekR1ReasoningParser(ReasoningParser):
    """
    Reasoning parser for DeepSeek R1 model.

    The DeepSeek R1 model uses <think>...</think> tokens to denote reasoning
    text. This parser extracts the reasoning content from the model output.

    Newer DeepSeek R1 checkpoints may omit the opening <think> token, so both
    the streaming and non-streaming paths treat text that precedes </think>
    as reasoning even when no <think> was seen.
    """

    # Vocabulary ids of the markers; resolved from the tokenizer in __init__.
    start_token_id: int
    end_token_id: int

    start_token: str = "<think>"
    end_token: str = "</think>"

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        """Resolve the vocabulary ids of the <think>/</think> markers.

        Raises:
            ValueError: if the base class did not record a model tokenizer.
            RuntimeError: if either marker is missing from the vocabulary.
        """
        super().__init__(tokenizer)

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ReasoningParser "
                "constructor during construction.")

        # ``self.vocab`` comes from the ReasoningParser base class
        # (presumably the tokenizer's token -> id mapping; verify there).
        self.start_token_id = self.vocab.get(self.start_token)
        self.end_token_id = self.vocab.get(self.end_token)
        if self.start_token_id is None or self.end_token_id is None:
            raise RuntimeError(
                "DeepSeek R1 reasoning parser could not locate think start/end "
                "tokens in the tokenizer!")

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        """Return True once the end-of-reasoning token appears in input_ids."""
        return self.end_token_id in input_ids

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """
        Extract the content after the end tokens.

        Returns the ids that follow the first end-of-reasoning token, or an
        empty list when that token is absent (or is the final id, in which
        case nothing follows it yet).
        """
        try:
            end_index = input_ids.index(self.end_token_id)
        except ValueError:
            return []
        return input_ids[end_index + 1:]

    def _strip_start_token(self, text: str) -> str:
        """Return the text after the first start marker.

        Anything before the marker is dropped along with it. If the marker
        is absent, ``text`` is returned unchanged.
        """
        before, marker, after = text.partition(self.start_token)
        return after if marker else before

    def _split_at_end_token(self, text: str) -> tuple[str, Optional[str]]:
        """Split ``text`` at the first end marker into (reasoning, content).

        ``content`` is None when nothing follows the marker. If the marker
        text is absent, the whole text is reported as reasoning rather than
        being sliced at a bogus -1 offset.
        """
        reasoning, _, content = text.partition(self.end_token)
        return reasoning, content or None

    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> Union[DeltaMessage, None]:
        """
        Extract reasoning content from a delta message.
        Handles streaming output where previous + delta = current.
        Uses token IDs for faster processing.
        For text <think>abc</think>xyz:
        - 'abc' goes to reasoning_content
        - 'xyz' goes to content

        Returns None for a delta consisting only of a <think> or </think>
        token, since it carries no text for either field.
        """
        # Skip single special tokens.
        if len(delta_token_ids) == 1 and delta_token_ids[0] in (
                self.start_token_id, self.end_token_id):
            return None

        text = delta_text
        if (self.start_token_id not in previous_token_ids
                and self.start_token_id in delta_token_ids):
            # <think> arrives in this delta: drop the marker (and anything
            # before it) so it never leaks into reasoning_content.
            text = self._strip_start_token(delta_text)
            if self.end_token_id not in delta_token_ids:
                # <think> in delta, no </think> in delta: reasoning continues.
                return DeltaMessage(reasoning_content=text)

        if self.end_token_id in delta_token_ids:
            # </think> in this delta (with or without a preceding <think>;
            # the model may omit <think>, see
            # https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f):
            # text before it is reasoning, text after it is content.
            reasoning_content, content = self._split_at_end_token(text)
            return DeltaMessage(
                reasoning_content=reasoning_content,
                content=content,
            )
        if self.end_token_id in previous_token_ids:
            # </think> already emitted: everything from here on is content.
            return DeltaMessage(content=delta_text)
        # No </think> yet: reasoning content continues.
        return DeltaMessage(reasoning_content=delta_text)

    def extract_reasoning_content(
            self, model_output: str, request: ChatCompletionRequest
    ) -> tuple[Optional[str], Optional[str]]:
        """
        Extract reasoning content from the model output.

        For text <think>abc</think>xyz:
        - 'abc' goes to reasoning_content
        - 'xyz' goes to content

        Returns:
            tuple[Optional[str], Optional[str]]: reasoning content and content.
            Content is None when </think> is missing or nothing follows it.
        """
        # Remove the start token (and anything before it) if present.
        model_output = self._strip_start_token(model_output)

        # DeepSeek R1 doesn't generate <think> now.
        # Thus we assume the reasoning content is always at the start.
        # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
        # Without </think> the whole output is reasoning and content is None.
        return self._split_at_end_token(model_output)