# minimax_m2_reasoning_parser.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Sequence

from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    DeltaMessage,
    ResponsesRequest,
)
from vllm.logger import init_logger
12
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
13
from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
14
from vllm.tokenizers import TokenizerLike
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39

# Module-level logger obtained from vLLM's init_logger, keyed by this module's name.
logger = init_logger(__name__)


class MiniMaxM2ReasoningParser(BaseThinkingReasoningParser):
    """
    Reasoning parser for the MiniMax M2 model.

    Supplies the pair of tag strings that delimit a reasoning span; this
    class defines nothing else, so all remaining behavior comes from
    ``BaseThinkingReasoningParser``.
    """

    @property
    def end_token(self) -> str:
        """Closing tag that terminates the reasoning span."""
        return "</think>"

    @property
    def start_token(self) -> str:
        """Opening tag that begins the reasoning span."""
        return "<think>"


class MiniMaxM2AppendThinkReasoningParser(ReasoningParser):
    """
    Reasoning parser for MiniMax M2 model that re-attaches ``<think>``.

    This variant does not separate reasoning from content. It prepends the
    opening ``<think>`` tag to the model output and returns the whole text
    as content, leaving the reasoning part unset.
    """

    def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
        """Initialize the parser and cache the id of the ``</think>`` token.

        Args:
            tokenizer: Tokenizer forwarded to the base class.
            *args: Extra positional arguments forwarded to the base class.
            **kwargs: Extra keyword arguments forwarded to the base class.
        """
        super().__init__(tokenizer, *args, **kwargs)
        # ``self.vocab`` is presumably populated by the base class from the
        # tokenizer -- verify. ``.get`` yields None when the token is absent.
        self.end_token_id = self.vocab.get("</think>")

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        """Return True if ``input_ids`` contains the ``</think>`` token id.

        Args:
            input_ids: Token ids to inspect.

        Returns:
            Whether any id in ``input_ids`` equals the cached end-token id.
        """
        end_token_id = self.end_token_id
        # Scan from the tail so a match near the end is found first.
        return any(input_id == end_token_id for input_id in reversed(input_ids))

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """Return ``input_ids`` unchanged; no ids are filtered out.

        Args:
            input_ids: Token ids produced so far.

        Returns:
            The same list object that was passed in.
        """
        return input_ids

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        """Wrap a streamed delta as content, prefixing the first with ``<think>``.

        Only ``delta_text`` and ``previous_token_ids`` are read; the other
        parameters are accepted but not used here.

        Args:
            previous_text: Text generated before this delta (unused).
            current_text: Text generated including this delta (unused).
            delta_text: Newly generated text for this step.
            previous_token_ids: Token ids generated before this delta; an
                empty sequence marks the first delta of the stream.
            current_token_ids: Token ids including this delta (unused).
            delta_token_ids: Token ids of this delta (unused).

        Returns:
            A ``DeltaMessage`` whose ``content`` is ``delta_text``, with
            ``"<think>"`` prepended when this is the first delta.
        """
        if len(previous_token_ids) == 0:
            # First chunk of the stream: re-attach the opening tag.
            delta_text = "<think>" + delta_text
        return DeltaMessage(content=delta_text)

    def extract_reasoning(
        self, model_output: str, request: ChatCompletionRequest | ResponsesRequest
    ) -> tuple[str | None, str | None]:
        """Return the full output as content with ``<think>`` prepended.

        Args:
            model_output: Complete text produced by the model.
            request: The originating request (unused).

        Returns:
            A ``(reasoning, content)`` tuple where ``reasoning`` is always
            ``None`` and ``content`` is ``"<think>" + model_output``.
        """
        return None, "<think>" + model_output