step3_reasoning_parser.py 4.2 KB
Newer Older
Song's avatar
Song committed
1
2
3
4
5
6
7
8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Sequence

import regex as re
from transformers import PreTrainedTokenizerBase

9
10
11
12
from vllm.entrypoints.openai.chat_completion.protocol import (
    ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import DeltaMessage
Song's avatar
Song committed
13
from vllm.logger import init_logger
14
from vllm.reasoning import ReasoningParser
Song's avatar
Song committed
15
16
17
18
19
20
21
22

# Module-level logger for this parser.
# NOTE(review): not referenced by any code visible in this module — confirm it
# is still needed.
logger = init_logger(__name__)


class Step3ReasoningParser(ReasoningParser):
    """
    Reasoning parser for Step3 model.

    The Step3 model uses </think> token to denote the end of reasoning
    text. This parser extracts all content before </think> as reasoning content.

    There is no start-of-reasoning marker: everything up to the first
    </think> is reasoning, and everything after it is content.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs):
        super().__init__(tokenizer, *args, **kwargs)

        self.think_end_token = "</think>"

        # Not used by the methods of this class; kept as a public attribute in
        # case other code relies on it.
        self.reasoning_regex = re.compile(rf"(.*?){self.think_end_token}", re.DOTALL)

        # `model_tokenizer` and `vocab` are read from the ReasoningParser base
        # class (presumably populated by super().__init__ — confirm there).
        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ReasoningParser "
                "constructor during construction."
            )

        self.think_end_token_id = self.vocab.get(self.think_end_token)
        if self.think_end_token_id is None:
            raise RuntimeError(
                "Step3 reasoning parser could not locate think end "
                "token in the tokenizer!"
            )

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        """
        Extract reasoning content from a delta message.
        Handles streaming output where previous + delta = current.
        Uses token IDs to decide which side of </think> the delta is on.
        For text "abc</think>xyz":
        - 'abc' goes to reasoning
        - 'xyz' goes to content

        Returns None when the delta is exactly the single </think> token,
        since it carries no text to emit.
        """
        end_id = self.think_end_token_id

        # Skip single special token
        if len(delta_token_ids) == 1 and delta_token_ids[0] == end_id:
            return None

        # Checked before the split below: once reasoning has ended, any later
        # </think> belongs to content. This matches extract_reasoning, which
        # splits only on the first occurrence.
        if end_id in previous_token_ids:
            return DeltaMessage(content=delta_text)

        if end_id in delta_token_ids:
            # </think> is in this delta: split it into reasoning and content.
            reasoning, sep, content = delta_text.partition(self.think_end_token)
            if not sep:
                # The token id is present, but its text is not in delta_text
                # (e.g. if the detokenizer dropped it). The boundary cannot be
                # located, so keep the whole delta as reasoning instead of
                # slicing at index -1 and corrupting the text.
                # NOTE(review): confirm this fallback is the desired attribution.
                return DeltaMessage(reasoning=delta_text)
            return DeltaMessage(
                reasoning=reasoning,
                content=content if content else None,
            )

        # No </think> seen yet, everything is reasoning
        return DeltaMessage(reasoning=delta_text)

    def extract_reasoning(
        self, model_output: str, request: ChatCompletionRequest
    ) -> tuple[str | None, str | None]:
        """
        Split a complete model output at the first </think>.

        Returns (reasoning, content). If </think> is absent, the whole output
        is reasoning and content is None. Content is also None when nothing
        follows </think>. `request` is not used by this implementation.
        """
        reasoning, sep, content = model_output.partition(self.think_end_token)
        if not sep:
            # If no </think> token, everything is reasoning content
            return model_output, None
        return reasoning, content if content else None

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        """Return True if </think> appears anywhere in `input_ids`."""
        return self.think_end_token_id in input_ids

    def is_reasoning_end_streaming(
        self, input_ids: list[int], delta_ids: list[int]
    ) -> bool:
        """Return True if </think> appears in the newest delta `delta_ids`."""
        return self.think_end_token_id in delta_ids

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """
        Return the token ids after the first </think>.

        Returns an empty list when </think> is absent, or when it is the last
        token (no content has been generated yet).
        """
        try:
            end_index = input_ids.index(self.think_end_token_id)
        except ValueError:
            return []
        # If </think> is the final token, this slice is already empty.
        return input_ids[end_index + 1 :]