Commit f4776ec3 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'minimax_m2' into 'v0.11.0-dev'

Add minimax_m2

See merge request dcutoolkit/deeplearing/vllm!258
parents e712dcbb 7636d436
...@@ -25,6 +25,7 @@ from .qwen3xml_tool_parser import Qwen3XMLToolParser ...@@ -25,6 +25,7 @@ from .qwen3xml_tool_parser import Qwen3XMLToolParser
from .seed_oss_tool_parser import SeedOssToolParser from .seed_oss_tool_parser import SeedOssToolParser
from .step3_tool_parser import Step3ToolParser from .step3_tool_parser import Step3ToolParser
from .xlam_tool_parser import xLAMToolParser from .xlam_tool_parser import xLAMToolParser
from .minimax_m2_tool_parser import MinimaxM2ToolParser
__all__ = [ __all__ = [
"ToolParser", "ToolParser",
...@@ -52,4 +53,5 @@ __all__ = [ ...@@ -52,4 +53,5 @@ __all__ = [
"SeedOssToolParser", "SeedOssToolParser",
"Step3ToolParser", "Step3ToolParser",
"OpenAIToolParser", "OpenAIToolParser",
"MinimaxM2ToolParser",
] ]
This diff is collapsed.
This diff is collapsed.
...@@ -52,6 +52,7 @@ _TEXT_GENERATION_MODELS = { ...@@ -52,6 +52,7 @@ _TEXT_GENERATION_MODELS = {
"MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"), "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
"MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"), "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
"MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"), "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
"MiniMaxM2ForCausalLM": ("minimax_m2", "MiniMaxM2ForCausalLM"),
# baichuan-7b, upper case 'C' in the class name # baichuan-7b, upper case 'C' in the class name
"BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
# baichuan-13b, lower case 'c' in the class name # baichuan-13b, lower case 'c' in the class name
......
...@@ -12,6 +12,7 @@ from .mistral_reasoning_parser import MistralReasoningParser ...@@ -12,6 +12,7 @@ from .mistral_reasoning_parser import MistralReasoningParser
from .qwen3_reasoning_parser import Qwen3ReasoningParser from .qwen3_reasoning_parser import Qwen3ReasoningParser
from .seedoss_reasoning_parser import SeedOSSReasoningParser from .seedoss_reasoning_parser import SeedOSSReasoningParser
from .step3_reasoning_parser import Step3ReasoningParser from .step3_reasoning_parser import Step3ReasoningParser
from .minimax_m2_reasoning_parser import MiniMaxM2ReasoningParser
__all__ = [ __all__ = [
"ReasoningParser", "ReasoningParser",
...@@ -27,4 +28,5 @@ __all__ = [ ...@@ -27,4 +28,5 @@ __all__ = [
"Step3ReasoningParser", "Step3ReasoningParser",
"GptOssReasoningParser", "GptOssReasoningParser",
"SeedOSSReasoningParser", "SeedOSSReasoningParser",
"MiniMaxM2ReasoningParser",
] ]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaMessage,
)
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser
from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
from vllm.transformers_utils.tokenizer import AnyTokenizer
logger = init_logger(__name__)
class MiniMaxM2ReasoningParser(BaseThinkingReasoningParser):
"""
Reasoning parser for MiniMax M2 model.
"""
@property
def start_token(self) -> str:
"""The token that starts reasoning content."""
return "<think>"
@property
def end_token(self) -> str:
"""The token that ends reasoning content."""
return "</think>"
class MiniMaxM2AppendThinkReasoningParser(ReasoningParser):
"""
Reasoning parser for MiniMax M2 model.
"""
def __init__(self, tokenizer: AnyTokenizer, *args, **kwargs):
super().__init__(tokenizer, *args, **kwargs)
self.end_token_id = self.vocab.get("</think>")
def is_reasoning_end(self, input_ids: list[int]) -> bool:
end_token_id = self.end_token_id
return any(input_id == end_token_id for input_id in reversed(input_ids))
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
return input_ids
def extract_reasoning_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
) -> DeltaMessage | None:
if len(previous_token_ids) == 0:
delta_text = "<think>" + delta_text
return DeltaMessage(content=delta_text)
def extract_reasoning(
self, model_output: str, request: ChatCompletionRequest
) -> tuple[str | None, str | None]:
return None, "<think>" + model_output
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment