# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Sequence
from typing import Optional, Union

import regex as re
from transformers import PreTrainedTokenizerBase

from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser, ReasoningParserManager

logger = init_logger(__name__)


@ReasoningParserManager.register_module("step3")
class Step3ReasoningParser(ReasoningParser):
    """
    Reasoning parser for the Step3 model.

    The Step3 model uses the </think> token to denote the end of reasoning
    text. Everything before the first </think> is reported as reasoning
    content; everything after it is reported as regular content.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs):
        """
        Set up the end-of-reasoning marker and resolve its token id.

        Raises:
            ValueError: if no tokenizer is available on the parser.
            RuntimeError: if the vocabulary has no entry for </think>.
        """
        super().__init__(tokenizer, *args, **kwargs)
        self.think_end_token = "</think>"

        # Lazily capture everything up to the first end marker. The marker is
        # escaped so it is always matched literally.
        self.reasoning_regex = re.compile(
            rf"(.*?){re.escape(self.think_end_token)}", re.DOTALL
        )

        # NOTE(review): `model_tokenizer` and `vocab` are not defined in this
        # class; presumably they are provided by ReasoningParser.
        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ReasoningParser "
                "constructor during construction."
            )

        self.think_end_token_id = self.vocab.get(self.think_end_token)
        if self.think_end_token_id is None:
            raise RuntimeError(
                "Step3 reasoning parser could not locate think end "
                "token in the tokenizer!"
            )

    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> Union[DeltaMessage, None]:
        """
        Extract reasoning content from a delta message.

        Handles streaming output where previous + delta = current, using
        token ids to decide which side of </think> the delta falls on.
        For text "abc</think>xyz":
        - 'abc' goes to reasoning_content
        - 'xyz' goes to content

        Returns:
            None when the delta is exactly the </think> token (nothing to
            emit), otherwise a DeltaMessage carrying reasoning_content
            and/or content.
        """
        end_id = self.think_end_token_id

        # A delta made up of only the end marker carries no visible text.
        if len(delta_token_ids) == 1 and delta_token_ids[0] == end_id:
            return None

        if end_id in delta_token_ids:
            # </think> is inside this delta: split the text around it.
            reasoning, marker, content = delta_text.partition(
                self.think_end_token
            )
            if not marker:
                # The token id is present but its text is not, so the split
                # point cannot be located. Slicing with the -1 returned by
                # str.find would drop the last character and emit a bogus
                # tail as content; best effort is to keep the whole delta as
                # reasoning so that no text is lost.
                return DeltaMessage(reasoning_content=delta_text)
            return DeltaMessage(
                reasoning_content=reasoning,
                content=content if content else None,
            )

        if end_id in previous_token_ids:
            # </think> was already seen earlier: everything is content.
            return DeltaMessage(content=delta_text)

        # No </think> seen yet: everything is still reasoning.
        return DeltaMessage(reasoning_content=delta_text)

    def extract_reasoning_content(
        self, model_output: str, request: ChatCompletionRequest
    ) -> tuple[Optional[str], Optional[str]]:
        """
        Split a complete model output into (reasoning_content, content).

        If the output contains no </think>, the whole output is treated as
        reasoning and content is None. Otherwise the split happens at the
        first </think>; an empty tail is reported as None.
        """
        reasoning, marker, content = model_output.partition(
            self.think_end_token
        )
        if not marker:
            # No end marker at all: the model never left reasoning.
            return model_output, None
        return reasoning, (content if content else None)

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        """Return True if the </think> token id occurs in ``input_ids``."""
        return self.think_end_token_id in input_ids

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """
        Return the token ids that follow the first </think> token.

        An empty list is returned when the marker is absent or when it is
        the last token (nothing follows it).
        """
        try:
            end_pos = input_ids.index(self.think_end_token_id)
        except ValueError:
            return []
        return input_ids[end_pos + 1 :]