# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Sequence
from typing import TYPE_CHECKING

from vllm.entrypoints.openai.engine.protocol import (
    DeltaMessage,
)
from vllm.logger import init_logger
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
from vllm.tokenizers import TokenizerLike

if TYPE_CHECKING:
    # Only needed for type annotations; avoids importing request models at runtime.
    from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
    from vllm.entrypoints.openai.responses.protocol import ResponsesRequest

logger = init_logger(__name__)


class MiniMaxM2ReasoningParser(BaseThinkingReasoningParser):
    """
    Reasoning parser for MiniMax M2 model.

    MiniMax M2 models don't generate <think> start token, only </think> end
    token. All content before </think> is reasoning, content after is the
    actual response.
    """

    @property
    def start_token(self) -> str:
        """The token that starts reasoning content."""
        return "<think>"

    @property
    def end_token(self) -> str:
        """The token that ends reasoning content."""
        return "</think>"

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        """
        Extract reasoning content from a delta message for streaming.

        MiniMax M2 models don't generate <think> start token, so we assume
        all content is reasoning until we encounter the </think> end token.

        Args:
            previous_text: Text generated before this delta (unused here).
            current_text: Text generated including this delta (unused here).
            delta_text: Newly generated text for this step.
            previous_token_ids: Token ids generated before this delta.
            current_token_ids: Token ids including this delta (unused here).
            delta_token_ids: Newly generated token ids for this step.

        Returns:
            ``None`` when the delta is exactly the end token; otherwise a
            ``DeltaMessage`` carrying reasoning and/or content text.
        """
        # A delta consisting solely of the end token carries no visible text.
        if len(delta_token_ids) == 1 and delta_token_ids[0] == self.end_token_id:
            return None

        # The end token was already emitted earlier, so we are past the
        # reasoning phase and everything from now on is the answer.
        if self.end_token_id in previous_token_ids:
            return DeltaMessage(content=delta_text)

        # The end token arrives in this delta: split the text around it.
        if self.end_token_id in delta_token_ids:
            end_index = delta_text.find(self.end_token)
            if end_index == -1:
                # The end-token id is present but its text is not (e.g. if it
                # was stripped during detokenization). Slicing at -1 would
                # garble the output, so fall back to the token position: if
                # the end token leads the delta, the text comes after it and
                # is content; otherwise attribute the text to reasoning.
                if delta_token_ids[0] == self.end_token_id:
                    return DeltaMessage(content=delta_text if delta_text else None)
                return DeltaMessage(reasoning=delta_text if delta_text else None)
            reasoning = delta_text[:end_index]
            content = delta_text[end_index + len(self.end_token) :]
            return DeltaMessage(
                reasoning=reasoning if reasoning else None,
                content=content if content else None,
            )

        # No end token yet, all content is reasoning.
        return DeltaMessage(reasoning=delta_text)


class MiniMaxM2AppendThinkReasoningParser(ReasoningParser):
    """
    Pass-through reasoning parser for MiniMax M2 that keeps reasoning inline.

    MiniMax M2 models don't generate the <think> start token themselves, so
    this parser re-inserts "<think>" in front of the model output and returns
    everything (reasoning and answer alike) as ordinary content. It never
    populates the separate reasoning field.
    """

    def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
        super().__init__(tokenizer, *args, **kwargs)
        # `vocab.get` yields None if the tag is not a single vocab entry
        # (assuming vocab is dict-like); is_reasoning_end then never matches it.
        self.end_token_id = self.vocab.get("</think>")
        self.start_token_id = self.vocab.get("<think>")

    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
        """
        Return True if the most recent think tag in ``input_ids`` is </think>.

        The ids are scanned from the end, so a later <think> re-opens
        reasoning. Returns False when neither tag occurs.
        """
        end_token_id = self.end_token_id
        start_token_id = self.start_token_id
        for input_id in reversed(input_ids):
            if input_id in (end_token_id, start_token_id):
                return input_id == end_token_id
        return False

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """Return ``input_ids`` unchanged: every token is treated as content."""
        return input_ids

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        """
        Stream the delta as content, prefixing "<think>" on the first delta.

        Only ``delta_text`` and ``previous_token_ids`` are consulted; an empty
        ``previous_token_ids`` marks the first chunk of the stream.
        """
        if len(previous_token_ids) == 0:
            # The model omits the opening tag, so restore it once.
            delta_text = "<think>" + delta_text
        return DeltaMessage(content=delta_text)

    def extract_reasoning(
        self, model_output: str, request: "ChatCompletionRequest | ResponsesRequest"
    ) -> tuple[str | None, str | None]:
        """
        Return ``(None, "<think>" + model_output)``.

        No reasoning is split out; the whole output (with the restored
        opening tag) is returned as content. ``request`` is unused.
        """
        return None, "<think>" + model_output