minimax_m2_reasoning_parser.py 3.79 KB
Newer Older
chenych's avatar
chenych committed
1
2
3
4
5
6
7
8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Sequence

from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    DeltaMessage,
chenych's avatar
chenych committed
9
    ResponsesRequest,
chenych's avatar
chenych committed
10
11
)
from vllm.logger import init_logger
chenych's avatar
chenych committed
12
from vllm.reasoning import ReasoningParser, ReasoningParserManager
chenych's avatar
chenych committed
13
14
15
16
17
from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
from vllm.transformers_utils.tokenizer import AnyTokenizer

# Module-level logger obtained from vLLM's init_logger helper.
logger = init_logger(__name__)

@ReasoningParserManager.register_module("minimax_m2")
class MiniMaxM2ReasoningParser(BaseThinkingReasoningParser):
    """
    Reasoning parser for MiniMax M2 model.

    MiniMax M2 models don't generate <think> start token, only </think> end
    token. All content before </think> is reasoning, content after is the
    actual response.
    """

    @property
    def start_token(self) -> str:
        """The token that starts reasoning content."""
        return "<think>"

    @property
    def end_token(self) -> str:
        """The token that ends reasoning content."""
        return "</think>"

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        """
        Extract reasoning content from a delta message for streaming.

        MiniMax M2 models don't generate <think> start token, so we assume
        all content is reasoning until we encounter the </think> end token.

        Args:
            previous_text: Text generated before this delta (unused).
            current_text: Text generated including this delta (unused).
            delta_text: Text newly generated in this step.
            previous_token_ids: Token ids generated before this delta.
            current_token_ids: Token ids including this delta (unused).
            delta_token_ids: Token ids newly generated in this step.

        Returns:
            ``None`` when the delta is exactly the end token; otherwise a
            ``DeltaMessage`` whose ``reasoning`` and/or ``content`` carry the
            delta text. Empty halves of a split are reported as ``None``.
        """
        end_token_id = self.end_token_id

        # The end token on its own produces no visible output.
        if len(delta_token_ids) == 1 and delta_token_ids[0] == end_token_id:
            return None

        # Reasoning already ended in an earlier delta: this is all content.
        if end_token_id in previous_token_ids:
            return DeltaMessage(content=delta_text)

        # No end token seen yet: still inside the reasoning section.
        if end_token_id not in delta_token_ids:
            return DeltaMessage(reasoning=delta_text)

        # The end token arrived inside a multi-token delta; split on its text.
        end_index = delta_text.find(self.end_token)
        if end_index == -1:
            # The token id is present but its text is not. This can happen,
            # for example, if special tokens are skipped during
            # detokenization. Slicing with -1 would yield overlapping,
            # truncated reasoning/content strings. Instead, attribute the
            # whole delta by the end token's position:
            # - if it is the final token, all the text precedes it, so the
            #   text is reasoning;
            # - otherwise, treat the text as content.
            if delta_token_ids[-1] == end_token_id:
                return DeltaMessage(reasoning=delta_text or None)
            return DeltaMessage(content=delta_text or None)

        reasoning = delta_text[:end_index]
        content = delta_text[end_index + len(self.end_token) :]
        return DeltaMessage(
            reasoning=reasoning or None,
            content=content or None,
        )

class MiniMaxM2AppendThinkReasoningParser(ReasoningParser):
    """
    Reasoning parser for MiniMax M2 model that keeps the thinking inline.

    Instead of separating reasoning from the answer, this parser returns the
    entire model output as content. It prepends the ``<think>`` tag, which
    MiniMax M2 does not generate itself, so the returned text starts with
    the opening tag.
    """

    def __init__(self, tokenizer: AnyTokenizer, *args, **kwargs):
        super().__init__(tokenizer, *args, **kwargs)
        # Id of "</think>", or None if the tokenizer's vocab lacks it.
        self.end_token_id = self.vocab.get("</think>")
        if self.end_token_id is None:
            # Not fatal: content passthrough still works. But reasoning-end
            # detection can never fire, so surface the misconfiguration
            # instead of failing silently.
            logger.warning(
                "%s: '</think>' token not found in the tokenizer vocabulary; "
                "is_reasoning_end() will always return False.",
                type(self).__name__,
            )

    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
        """Return True if the ``</think>`` token occurs in *input_ids*."""
        end_token_id = self.end_token_id
        if end_token_id is None:
            return False
        # Scan from the end so a recently emitted end token is found quickly.
        return any(input_id == end_token_id for input_id in reversed(input_ids))

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """Return *input_ids* unchanged: every token is treated as content."""
        return input_ids

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        """
        Return the delta as content, never as reasoning.

        The first delta of the stream (no previously generated token ids)
        gets ``<think>`` prepended. Later deltas pass through unchanged.
        """
        # len() rather than truthiness, so array-like id containers also work.
        if len(previous_token_ids) == 0:
            delta_text = "<think>" + delta_text
        return DeltaMessage(content=delta_text)

    def extract_reasoning(
        self, model_output: str, request: ChatCompletionRequest | ResponsesRequest
    ) -> tuple[str | None, str | None]:
        """
        Return ``(None, "<think>" + model_output)``.

        No text is classified as reasoning; the full output, with the
        opening tag restored, is returned as content.
        """
        return None, "<think>" + model_output