Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f4776ec3
Commit
f4776ec3
authored
Nov 17, 2025
by
zhuwenwen
Browse files
Merge branch 'minimax_m2' into 'v0.11.0-dev'
Add minimax_m2 See merge request dcutoolkit/deeplearing/vllm!258
parents
e712dcbb
7636d436
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
1266 additions
and
0 deletions
+1266
-0
vllm/entrypoints/openai/tool_parsers/__init__.py
vllm/entrypoints/openai/tool_parsers/__init__.py
+2
-0
vllm/entrypoints/openai/tool_parsers/minimax_m2_tool_parser.py
...entrypoints/openai/tool_parsers/minimax_m2_tool_parser.py
+643
-0
vllm/model_executor/models/minimax_m2.py
vllm/model_executor/models/minimax_m2.py
+552
-0
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+1
-0
vllm/reasoning/__init__.py
vllm/reasoning/__init__.py
+2
-0
vllm/reasoning/minimax_m2_reasoning_parser.py
vllm/reasoning/minimax_m2_reasoning_parser.py
+66
-0
No files found.
vllm/entrypoints/openai/tool_parsers/__init__.py
View file @
f4776ec3
...
...
@@ -25,6 +25,7 @@ from .qwen3xml_tool_parser import Qwen3XMLToolParser
from
.seed_oss_tool_parser
import
SeedOssToolParser
from
.step3_tool_parser
import
Step3ToolParser
from
.xlam_tool_parser
import
xLAMToolParser
from
.minimax_m2_tool_parser
import
MinimaxM2ToolParser
__all__
=
[
"ToolParser"
,
...
...
@@ -52,4 +53,5 @@ __all__ = [
"SeedOssToolParser"
,
"Step3ToolParser"
,
"OpenAIToolParser"
,
"MinimaxM2ToolParser"
,
]
vllm/entrypoints/openai/tool_parsers/minimax_m2_tool_parser.py
0 → 100644
View file @
f4776ec3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
json
import
uuid
from
collections.abc
import
Sequence
from
typing
import
Any
import
regex
as
re
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
DeltaFunctionCall
,
DeltaMessage
,
DeltaToolCall
,
ExtractedToolCallInformation
,
FunctionCall
,
ToolCall
,
)
from
vllm.entrypoints.openai.tool_parsers.abstract_tool_parser
import
(
ToolParser
,
)
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
logger
=
init_logger
(
__name__
)
class
MinimaxM2ToolParser
(
ToolParser
):
def
__init__
(
self
,
tokenizer
:
AnyTokenizer
):
super
().
__init__
(
tokenizer
)
self
.
prev_tool_call_arr
:
list
[
dict
]
=
[]
# Sentinel tokens
self
.
tool_call_start_token
:
str
=
"<minimax:tool_call>"
self
.
tool_call_end_token
:
str
=
"</minimax:tool_call>"
self
.
invoke_start_prefix
:
str
=
"<invoke name="
self
.
invoke_end_token
:
str
=
"</invoke>"
self
.
parameter_prefix
:
str
=
"<parameter name="
self
.
parameter_end_token
:
str
=
"</parameter>"
# Streaming state variables
self
.
current_tool_name_sent
:
bool
=
False
# Override base class type - we use string IDs for tool calls
self
.
current_tool_id
:
str
|
None
=
None
# type: ignore
self
.
streamed_args_for_tool
:
list
[
str
]
=
[]
self
.
is_tool_call_started
:
bool
=
False
self
.
failed_count
:
int
=
0
# Initialize streaming state variables
self
.
current_tool_index
:
int
=
0
self
.
invoke_index
:
int
=
0
self
.
header_sent
:
bool
=
False
self
.
current_function_name
:
str
|
None
=
None
self
.
current_param_name
:
str
|
None
=
None
self
.
current_param_value
:
str
=
""
self
.
param_count
:
int
=
0
self
.
in_param
:
bool
=
False
self
.
in_function
:
bool
=
False
self
.
accumulated_text
:
str
=
""
self
.
json_started
:
bool
=
False
self
.
json_closed
:
bool
=
False
self
.
accumulated_params
:
dict
=
{}
self
.
streaming_request
:
ChatCompletionRequest
|
None
=
None
# Enhanced streaming state - reset for each new message
self
.
_reset_streaming_state
()
# Regex patterns for complete parsing
self
.
tool_call_complete_regex
=
re
.
compile
(
r
"<minimax:tool_call>(.*?)</minimax:tool_call>"
,
re
.
DOTALL
)
self
.
invoke_complete_regex
=
re
.
compile
(
r
"<invoke name=(.*?)</invoke>"
,
re
.
DOTALL
)
self
.
parameter_complete_regex
=
re
.
compile
(
r
"<parameter name=(.*?)</parameter>"
,
re
.
DOTALL
)
if
not
self
.
model_tokenizer
:
raise
ValueError
(
"The model tokenizer must be passed to the ToolParser "
"constructor during construction."
)
self
.
tool_call_start_token_id
=
self
.
vocab
.
get
(
self
.
tool_call_start_token
)
self
.
tool_call_end_token_id
=
self
.
vocab
.
get
(
self
.
tool_call_end_token
)
if
self
.
tool_call_start_token_id
is
None
or
self
.
tool_call_end_token_id
is
None
:
raise
RuntimeError
(
"MiniMax M2 Tool parser could not locate tool call start/end "
"tokens in the tokenizer!"
)
logger
.
debug
(
"vLLM Successfully import tool parser %s !"
,
self
.
__class__
.
__name__
)
def
_generate_tool_call_id
(
self
)
->
str
:
"""Generate a unique tool call ID."""
return
f
"call_
{
uuid
.
uuid4
().
hex
[:
24
]
}
"
def
_reset_streaming_state
(
self
):
"""Reset all streaming state."""
self
.
current_tool_index
=
0
self
.
invoke_index
=
0
self
.
is_tool_call_started
=
False
self
.
header_sent
=
False
self
.
current_tool_id
=
None
self
.
current_function_name
=
None
self
.
current_param_name
=
None
self
.
current_param_value
=
""
self
.
param_count
=
0
self
.
in_param
=
False
self
.
in_function
=
False
self
.
accumulated_text
=
""
self
.
json_started
=
False
self
.
json_closed
=
False
# Store accumulated parameters for type conversion
self
.
accumulated_params
=
{}
self
.
streaming_request
=
None
# Clear previous tool call history to avoid state pollution
self
.
prev_tool_call_arr
.
clear
()
def
_extract_name
(
self
,
name_str
:
str
)
->
str
:
"""Extract name from quoted string."""
name_str
=
name_str
.
strip
()
if
(
name_str
.
startswith
(
'"'
)
and
name_str
.
endswith
(
'"'
)
or
name_str
.
startswith
(
"'"
)
and
name_str
.
endswith
(
"'"
)
):
return
name_str
[
1
:
-
1
]
return
name_str
def
_convert_param_value
(
self
,
value
:
str
,
param_type
:
str
)
->
Any
:
"""Convert parameter value to the correct type."""
if
value
.
lower
()
==
"null"
:
return
None
param_type
=
param_type
.
lower
()
if
param_type
in
[
"string"
,
"str"
,
"text"
]:
return
value
elif
param_type
in
[
"integer"
,
"int"
]:
try
:
return
int
(
value
)
except
(
ValueError
,
TypeError
):
return
value
elif
param_type
in
[
"number"
,
"float"
]:
try
:
val
=
float
(
value
)
return
val
if
val
!=
int
(
val
)
else
int
(
val
)
except
(
ValueError
,
TypeError
):
return
value
elif
param_type
in
[
"boolean"
,
"bool"
]:
return
value
.
lower
()
in
[
"true"
,
"1"
]
elif
param_type
in
[
"object"
,
"array"
]:
try
:
return
json
.
loads
(
value
)
except
json
.
JSONDecodeError
:
return
value
else
:
# Try JSON parse first, fallback to string
try
:
return
json
.
loads
(
value
)
except
json
.
JSONDecodeError
:
return
value
def
_parse_single_invoke
(
self
,
invoke_str
:
str
,
tools
:
list
|
None
)
->
ToolCall
|
None
:
"""Parse a single <invoke> block."""
# Extract function name
name_match
=
re
.
search
(
r
"^([^>]+)"
,
invoke_str
)
if
not
name_match
:
return
None
function_name
=
self
.
_extract_name
(
name_match
.
group
(
1
))
# Get parameter configuration
param_config
=
{}
if
tools
:
for
tool
in
tools
:
if
(
hasattr
(
tool
,
"function"
)
and
tool
.
function
.
name
==
function_name
and
hasattr
(
tool
.
function
,
"parameters"
)
):
params
=
tool
.
function
.
parameters
if
isinstance
(
params
,
dict
)
and
"properties"
in
params
:
param_config
=
params
[
"properties"
]
break
# Extract parameters
param_dict
=
{}
for
match
in
self
.
parameter_complete_regex
.
findall
(
invoke_str
):
param_match
=
re
.
search
(
r
"^([^>]+)>(.*)"
,
match
,
re
.
DOTALL
)
if
param_match
:
param_name
=
self
.
_extract_name
(
param_match
.
group
(
1
))
param_value
=
param_match
.
group
(
2
).
strip
()
if
param_value
.
startswith
(
"
\n
"
):
param_value
=
param_value
[
1
:]
if
param_value
.
endswith
(
"
\n
"
):
param_value
=
param_value
[:
-
1
]
# Get parameter type
param_type
=
"string"
if
(
param_name
in
param_config
and
isinstance
(
param_config
[
param_name
],
dict
)
and
"type"
in
param_config
[
param_name
]
):
param_type
=
param_config
[
param_name
][
"type"
]
# Convert value
param_dict
[
param_name
]
=
self
.
_convert_param_value
(
param_value
,
param_type
)
return
ToolCall
(
type
=
"function"
,
function
=
FunctionCall
(
name
=
function_name
,
arguments
=
json
.
dumps
(
param_dict
,
ensure_ascii
=
False
),
),
)
def
extract_tool_calls
(
self
,
model_output
:
str
,
request
:
ChatCompletionRequest
,
)
->
ExtractedToolCallInformation
:
"""Extract tool calls from complete model output (non-streaming)."""
# Quick check
if
self
.
tool_call_start_token
not
in
model_output
:
return
ExtractedToolCallInformation
(
tools_called
=
False
,
tool_calls
=
[],
content
=
model_output
)
try
:
tool_calls
=
[]
# Find all complete tool_call blocks
for
tool_call_match
in
self
.
tool_call_complete_regex
.
findall
(
model_output
):
# Find all invokes within this tool_call
for
invoke_match
in
self
.
invoke_complete_regex
.
findall
(
tool_call_match
):
tool_call
=
self
.
_parse_single_invoke
(
invoke_match
,
request
.
tools
if
request
else
None
)
if
tool_call
:
tool_calls
.
append
(
tool_call
)
if
not
tool_calls
:
return
ExtractedToolCallInformation
(
tools_called
=
False
,
tool_calls
=
[],
content
=
model_output
)
# Update prev_tool_call_arr
self
.
prev_tool_call_arr
.
clear
()
for
tool_call
in
tool_calls
:
self
.
prev_tool_call_arr
.
append
(
{
"name"
:
tool_call
.
function
.
name
,
"arguments"
:
tool_call
.
function
.
arguments
,
}
)
# Extract content before first tool call
first_tool_idx
=
model_output
.
find
(
self
.
tool_call_start_token
)
content
=
model_output
[:
first_tool_idx
]
if
first_tool_idx
>
0
else
None
return
ExtractedToolCallInformation
(
tools_called
=
True
,
tool_calls
=
tool_calls
,
content
=
content
)
except
Exception
:
logger
.
exception
(
"Error extracting tool calls"
)
return
ExtractedToolCallInformation
(
tools_called
=
False
,
tool_calls
=
[],
content
=
model_output
)
def
extract_tool_calls_streaming
(
self
,
previous_text
:
str
,
current_text
:
str
,
delta_text
:
str
,
previous_token_ids
:
Sequence
[
int
],
# pylint: disable=unused-argument
current_token_ids
:
Sequence
[
int
],
# pylint: disable=unused-argument
delta_token_ids
:
Sequence
[
int
],
request
:
ChatCompletionRequest
,
)
->
DeltaMessage
|
None
:
"""Extract tool calls from streaming model output."""
# Store request for type conversion
if
not
previous_text
or
self
.
tool_call_start_token
in
delta_text
:
self
.
_reset_streaming_state
()
self
.
streaming_request
=
request
# If no delta text, return None unless it's an EOS token after tools
if
not
delta_text
:
# Check if this is an EOS token after all tool calls are complete
if
delta_token_ids
and
self
.
tool_call_end_token_id
not
in
delta_token_ids
:
# Count complete tool calls
complete_calls
=
len
(
self
.
tool_call_complete_regex
.
findall
(
current_text
)
)
# If we have completed tool calls and populated prev_tool_call_arr
if
complete_calls
>
0
and
len
(
self
.
prev_tool_call_arr
)
>
0
:
# Check if all tool calls are closed
open_calls
=
current_text
.
count
(
self
.
tool_call_start_token
)
-
current_text
.
count
(
self
.
tool_call_end_token
)
if
open_calls
==
0
:
# Return empty delta for finish_reason processing
return
DeltaMessage
(
content
=
""
)
elif
not
self
.
is_tool_call_started
and
current_text
:
# This is a regular content response that's now complete
return
DeltaMessage
(
content
=
""
)
return
None
# Update accumulated text
self
.
accumulated_text
=
current_text
# Check if we need to advance to next tool
if
self
.
json_closed
and
not
self
.
in_function
:
# Check if this tool call has ended
invoke_ends
=
current_text
.
count
(
self
.
invoke_end_token
)
if
invoke_ends
>
self
.
current_tool_index
:
# This tool has ended, advance to next
self
.
current_tool_index
+=
1
self
.
header_sent
=
False
self
.
param_count
=
0
self
.
json_started
=
False
self
.
json_closed
=
False
self
.
in_function
=
False
# Now we can safely set this to False
self
.
accumulated_params
=
{}
# Continue processing next tool
return
None
# Handle normal content before tool calls
if
not
self
.
is_tool_call_started
:
# Check if tool call is starting
if
(
self
.
tool_call_start_token_id
in
delta_token_ids
or
self
.
tool_call_start_token
in
delta_text
):
self
.
is_tool_call_started
=
True
# Return any content before the tool call
if
self
.
tool_call_start_token
in
delta_text
:
content_before
=
delta_text
[
:
delta_text
.
index
(
self
.
tool_call_start_token
)
]
if
content_before
:
return
DeltaMessage
(
content
=
content_before
)
return
None
else
:
# Check if we're between tool calls - skip whitespace
if
(
current_text
.
rstrip
().
endswith
(
self
.
tool_call_end_token
)
and
delta_text
.
strip
()
==
""
):
# We just ended a tool call, skip whitespace
return
None
# Normal content, no tool call
return
DeltaMessage
(
content
=
delta_text
)
# Check if we're between tool calls (waiting for next one)
invoke_starts_count
=
current_text
.
count
(
self
.
invoke_start_prefix
)
if
self
.
current_tool_index
>=
invoke_starts_count
:
# We're past all tool calls, shouldn't be here
return
None
# Find the current tool call portion
invoke_start_positions
:
list
[
int
]
=
[]
idx
=
0
while
True
:
idx
=
current_text
.
find
(
self
.
invoke_start_prefix
,
idx
)
if
idx
==
-
1
:
break
invoke_start_positions
.
append
(
idx
)
idx
+=
len
(
self
.
invoke_start_prefix
)
if
self
.
current_tool_index
>=
len
(
invoke_start_positions
):
# No more tool calls to process yet
return
None
invoke_start_idx
=
invoke_start_positions
[
self
.
current_tool_index
]
# Find where this tool call ends (or current position if not ended yet)
invoke_end_idx
=
current_text
.
find
(
self
.
invoke_end_token
,
invoke_start_idx
)
if
invoke_end_idx
==
-
1
:
tool_text
=
current_text
[
invoke_start_idx
:]
else
:
tool_text
=
current_text
[
invoke_start_idx
:
invoke_end_idx
+
len
(
self
.
invoke_end_token
)
]
# Looking for function header
if
not
self
.
header_sent
:
if
self
.
invoke_start_prefix
in
tool_text
:
func_start
=
tool_text
.
find
(
self
.
invoke_start_prefix
)
+
len
(
self
.
invoke_start_prefix
)
# Find the end quote for the function name
func_end
=
tool_text
.
find
(
">"
,
func_start
)
if
func_end
!=
-
1
:
# Found complete function name
function_name_raw
=
tool_text
[
func_start
:
func_end
]
self
.
current_function_name
=
self
.
_extract_name
(
function_name_raw
)
self
.
current_tool_id
=
self
.
_generate_tool_call_id
()
self
.
header_sent
=
True
self
.
in_function
=
True
# Add to prev_tool_call_arr immediately when we detect a tool call
# Each tool call should be recorded regardless of function name
# Ensure we don't add the same tool call index multiple times
if
len
(
self
.
prev_tool_call_arr
)
<=
self
.
current_tool_index
:
self
.
prev_tool_call_arr
.
append
(
{
"name"
:
self
.
current_function_name
,
"arguments"
:
"{}"
,
# Placeholder, will be updated later
}
)
# Send header with function info
return
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
current_tool_index
,
id
=
self
.
current_tool_id
,
function
=
DeltaFunctionCall
(
name
=
self
.
current_function_name
,
arguments
=
""
),
type
=
"function"
,
)
]
)
return
None
# We've sent header, now handle function body
if
self
.
in_function
:
# Send opening brace if not sent yet
if
self
.
in_function
and
not
self
.
json_started
:
self
.
json_started
=
True
return
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
current_tool_index
,
function
=
DeltaFunctionCall
(
arguments
=
"{"
),
)
]
)
# Make sure json_started is set if we're processing parameters
if
not
self
.
json_started
:
self
.
json_started
=
True
# Check for function end in accumulated text
if
not
self
.
json_closed
and
self
.
invoke_end_token
in
tool_text
:
# Count total parameters in the tool text
total_param_count
=
tool_text
.
count
(
self
.
parameter_prefix
)
# Only close JSON if all parameters have been processed
if
self
.
param_count
>=
total_param_count
:
# Close JSON
self
.
json_closed
=
True
# Extract complete tool call
# Find the invoke content
invoke_start
=
tool_text
.
find
(
self
.
invoke_start_prefix
)
+
len
(
self
.
invoke_start_prefix
)
invoke_content_end
=
tool_text
.
find
(
self
.
invoke_end_token
,
invoke_start
)
if
invoke_content_end
!=
-
1
:
invoke_content
=
tool_text
[
invoke_start
:
invoke_content_end
]
# Parse to get the complete arguments
try
:
parsed_tool
=
self
.
_parse_single_invoke
(
invoke_content
,
self
.
streaming_request
.
tools
if
self
.
streaming_request
else
None
,
)
if
parsed_tool
and
self
.
current_tool_index
<
len
(
self
.
prev_tool_call_arr
):
# Update existing entry in prev_tool_call_arr
args
=
parsed_tool
.
function
.
arguments
self
.
prev_tool_call_arr
[
self
.
current_tool_index
][
"arguments"
]
=
args
except
Exception
:
pass
# Ignore parsing errors during streaming
result
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
current_tool_index
,
function
=
DeltaFunctionCall
(
arguments
=
"}"
),
)
]
)
# Reset state for next tool
self
.
json_closed
=
True
self
.
in_function
=
False
self
.
accumulated_params
=
{}
logger
.
debug
(
"[M2_STREAMING] Tool call completed"
)
return
result
else
:
# Don't close JSON yet, continue processing parameters
return
None
# Look for parameters
# Find all parameter starts
param_starts
=
[]
idx
=
0
while
True
:
idx
=
tool_text
.
find
(
self
.
parameter_prefix
,
idx
)
if
idx
==
-
1
:
break
param_starts
.
append
(
idx
)
idx
+=
len
(
self
.
parameter_prefix
)
# Check if we should start a new parameter
if
(
not
self
.
in_param
and
self
.
param_count
<
len
(
param_starts
)
and
len
(
param_starts
)
>
self
.
param_count
):
# Process the next parameter
param_idx
=
param_starts
[
self
.
param_count
]
param_start
=
param_idx
+
len
(
self
.
parameter_prefix
)
remaining
=
tool_text
[
param_start
:]
if
">"
in
remaining
:
# We have the complete parameter name
name_end
=
remaining
.
find
(
">"
)
param_name_raw
=
remaining
[:
name_end
]
self
.
current_param_name
=
self
.
_extract_name
(
param_name_raw
)
# Find the parameter value
value_start
=
param_start
+
name_end
+
1
value_text
=
tool_text
[
value_start
:]
if
value_text
.
startswith
(
"
\n
"
):
value_text
=
value_text
[
1
:]
# Find where this parameter ends
param_end_idx
=
value_text
.
find
(
self
.
parameter_end_token
)
if
param_end_idx
==
-
1
:
# No closing tag, look for next parameter or function end
next_param_idx
=
value_text
.
find
(
self
.
parameter_prefix
)
func_end_idx
=
value_text
.
find
(
self
.
invoke_end_token
)
if
next_param_idx
!=
-
1
and
(
func_end_idx
==
-
1
or
next_param_idx
<
func_end_idx
):
param_end_idx
=
next_param_idx
elif
func_end_idx
!=
-
1
:
param_end_idx
=
func_end_idx
else
:
# Neither found, check if tool call is complete
if
self
.
invoke_end_token
in
tool_text
:
# Tool call and parameter is complete
param_end_idx
=
len
(
value_text
)
else
:
# Still streaming, wait for more content
return
None
if
param_end_idx
!=
-
1
:
# Complete parameter found
param_value
=
value_text
[:
param_end_idx
]
if
param_value
.
endswith
(
"
\n
"
):
param_value
=
param_value
[:
-
1
]
# Store raw value for later processing
self
.
accumulated_params
[
self
.
current_param_name
]
=
param_value
# Get parameter configuration for type conversion
param_config
=
{}
if
self
.
streaming_request
and
self
.
streaming_request
.
tools
:
for
tool
in
self
.
streaming_request
.
tools
:
if
(
hasattr
(
tool
,
"function"
)
and
tool
.
function
.
name
==
self
.
current_function_name
and
hasattr
(
tool
.
function
,
"parameters"
)
):
params
=
tool
.
function
.
parameters
if
(
isinstance
(
params
,
dict
)
and
"properties"
in
params
):
param_config
=
params
[
"properties"
]
break
# Get parameter type
param_type
=
"string"
if
(
self
.
current_param_name
in
param_config
and
isinstance
(
param_config
[
self
.
current_param_name
],
dict
)
and
"type"
in
param_config
[
self
.
current_param_name
]
):
param_type
=
param_config
[
self
.
current_param_name
][
"type"
]
# Convert param value to appropriate type
converted_value
=
self
.
_convert_param_value
(
param_value
,
param_type
)
# Build JSON fragment based on the converted type
# Use json.dumps to properly serialize the value
serialized_value
=
json
.
dumps
(
converted_value
,
ensure_ascii
=
False
)
if
self
.
param_count
==
0
:
json_fragment
=
(
f
'"
{
self
.
current_param_name
}
":
{
serialized_value
}
'
)
else
:
json_fragment
=
(
f
', "
{
self
.
current_param_name
}
":
{
serialized_value
}
'
)
self
.
param_count
+=
1
return
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
current_tool_index
,
function
=
DeltaFunctionCall
(
arguments
=
json_fragment
),
)
]
)
return
None
vllm/model_executor/models/minimax_m2.py
0 → 100644
View file @
f4776ec3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 The MiniMax AI team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only MiniMaxM2 model."""
from
collections.abc
import
Iterable
from
typing
import
Any
import
torch
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
vllm.attention
import
Attention
# from vllm.compilation.decorators import support_torch_compile
from
vllm.config
import
CacheConfig
,
ModelConfig
,
VllmConfig
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
,
)
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
,
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.mamba.linear_attn
import
MiniMaxText01RMSNormTP
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
,
)
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
maybe_remap_kv_scale_name
,
)
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
PPMissingLayer
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
,
)
class
MiniMaxM2MoE
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
quant_config
:
QuantizationConfig
|
None
=
None
,
prefix
:
str
=
""
,
):
super
().
__init__
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
if
self
.
tp_size
>
config
.
num_local_experts
:
raise
ValueError
(
f
"Tensor parallel size
{
self
.
tp_size
}
is greater than "
f
"the number of experts
{
config
.
num_local_experts
}
."
)
self
.
use_routing_bias
=
getattr
(
config
,
"use_routing_bias"
,
False
)
if
self
.
use_routing_bias
:
self
.
e_score_correction_bias
=
nn
.
Parameter
(
torch
.
empty
(
config
.
num_local_experts
,
dtype
=
torch
.
float32
)
)
self
.
e_score_correction_bias
.
weight_loader
=
(
MiniMaxM2MoE
.
ebias_weight_loader
)
else
:
self
.
e_score_correction_bias
=
None
self
.
experts
=
FusedMoE
(
num_experts
=
config
.
num_local_experts
,
top_k
=
config
.
num_experts_per_tok
,
scoring_func
=
config
.
scoring_func
,
use_grouped_topk
=
True
,
num_expert_group
=
1
,
topk_group
=
1
,
e_score_correction_bias
=
self
.
e_score_correction_bias
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
reduce_results
=
False
,
renormalize
=
True
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.experts"
,
)
self
.
gate
=
ReplicatedLinear
(
config
.
hidden_size
,
config
.
num_local_experts
,
bias
=
False
,
params_dtype
=
torch
.
float32
,
quant_config
=
None
,
prefix
=
f
"
{
prefix
}
.gate"
,
)
@
staticmethod
def
ebias_weight_loader
(
param
:
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
)
->
None
:
assert
param
.
size
()
==
loaded_weight
.
size
()
param
.
data
.
copy_
(
loaded_weight
.
to
(
torch
.
float32
))
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
num_tokens
,
hidden_dim
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
# router_logits: (num_tokens, n_experts)
router_logits
,
_
=
self
.
gate
(
hidden_states
.
to
(
torch
.
float32
))
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
final_hidden_states
=
final_hidden_states
if
self
.
tp_size
>
1
:
final_hidden_states
=
tensor_model_parallel_all_reduce
(
final_hidden_states
)
return
final_hidden_states
.
view
(
num_tokens
,
hidden_dim
)
class
MiniMaxM2Attention
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
:
int
,
num_heads
:
int
,
num_kv_heads
:
int
,
rotary_dim
:
int
,
rope_theta
:
float
=
10000
,
rope_scaling
:
dict
[
str
,
Any
]
|
None
=
None
,
attn_window_size
:
int
|
None
=
None
,
max_position_embeddings
:
int
=
8192
,
head_dim
:
int
|
None
=
None
,
rms_norm_eps
:
float
=
1e-06
,
qkv_bias
:
bool
=
False
,
cache_config
:
CacheConfig
|
None
=
None
,
quant_config
:
QuantizationConfig
|
None
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
hidden_size
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
total_num_heads
=
num_heads
assert
self
.
total_num_heads
%
tp_size
==
0
self
.
num_heads
=
self
.
total_num_heads
//
tp_size
self
.
total_num_kv_heads
=
num_kv_heads
if
self
.
total_num_kv_heads
>=
tp_size
:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert
self
.
total_num_kv_heads
%
tp_size
==
0
else
:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert
tp_size
%
self
.
total_num_kv_heads
==
0
self
.
num_kv_heads
=
max
(
1
,
self
.
total_num_kv_heads
//
tp_size
)
self
.
head_dim
=
head_dim
or
(
hidden_size
//
self
.
total_num_heads
)
self
.
q_size
=
self
.
num_heads
*
self
.
head_dim
self
.
kv_size
=
self
.
num_kv_heads
*
self
.
head_dim
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
rope_theta
=
rope_theta
self
.
max_position_embeddings
=
max_position_embeddings
self
.
qkv_proj
=
QKVParallelLinear
(
hidden_size
,
self
.
head_dim
,
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
bias
=
qkv_bias
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.qkv_proj"
,
)
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.o_proj"
,
)
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
rotary_dim
=
rotary_dim
,
max_position
=
max_position_embeddings
,
base
=
rope_theta
,
rope_scaling
=
rope_scaling
,
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
,
per_layer_sliding_window
=
attn_window_size
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
,
)
self
.
q_norm
=
MiniMaxText01RMSNormTP
(
self
.
head_dim
*
self
.
total_num_heads
,
eps
=
rms_norm_eps
)
self
.
k_norm
=
MiniMaxText01RMSNormTP
(
self
.
head_dim
*
self
.
total_num_kv_heads
,
eps
=
rms_norm_eps
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
=
self
.
q_norm
(
q
)
k
=
self
.
k_norm
(
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
class
MiniMaxM2DecoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
prefix
:
str
,
model_config
:
ModelConfig
,
cache_config
:
CacheConfig
|
None
=
None
,
quant_config
:
QuantizationConfig
|
None
=
None
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
rope_theta
=
getattr
(
config
,
"rope_theta"
,
10000
)
rope_scaling
=
getattr
(
config
,
"rope_scaling"
,
None
)
max_position_embeddings
=
getattr
(
config
,
"max_position_embeddings"
,
8192
)
if
hasattr
(
config
,
"max_model_len"
)
and
isinstance
(
config
.
max_model_len
,
int
):
max_position_embeddings
=
max
(
config
.
max_position_embeddings
,
config
.
max_model_len
)
# DecoderLayers are created with `make_layers` which passes the prefix
# with the layer's index.
layer_idx
=
int
(
prefix
.
split
(
sep
=
"."
)[
-
1
])
self
.
layer_idx
=
layer_idx
self
.
self_attn
=
MiniMaxM2Attention
(
hidden_size
=
self
.
hidden_size
,
num_heads
=
config
.
num_attention_heads
,
num_kv_heads
=
config
.
num_key_value_heads
,
rotary_dim
=
config
.
rotary_dim
,
rope_theta
=
rope_theta
,
rope_scaling
=
rope_scaling
,
max_position_embeddings
=
max_position_embeddings
,
rms_norm_eps
=
config
.
rms_norm_eps
,
qkv_bias
=
getattr
(
config
,
"attention_bias"
,
False
),
head_dim
=
getattr
(
config
,
"head_dim"
,
None
),
cache_config
=
cache_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
)
self
.
block_sparse_moe
=
MiniMaxM2MoE
(
config
=
config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
,
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
:
# Self Attention
if
residual
is
None
:
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
else
:
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
)
# Fully Connected
hidden_states
,
residual
=
self
.
post_attention_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
block_sparse_moe
(
hidden_states
)
return
hidden_states
,
residual
# @support_torch_compile
class
MiniMaxM2Model
(
nn
.
Module
):
fall_back_to_pt_during_load
=
False
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
model_config
=
vllm_config
.
model_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
vocab_size
=
config
.
vocab_size
if
get_pp_group
().
is_first_rank
:
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
None
,
prefix
=
f
"
{
prefix
}
.embed_tokens"
,
)
else
:
self
.
embed_tokens
=
PPMissingLayer
()
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
lambda
prefix
:
MiniMaxM2DecoderLayer
(
config
,
prefix
,
model_config
=
model_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
),
prefix
=
f
"
{
prefix
}
.layers"
,
)
if
get_pp_group
().
is_last_rank
:
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
else
:
self
.
norm
=
PPMissingLayer
()
self
.
make_empty_intermediate_tensors
=
make_empty_intermediate_tensors_factory
(
[
"hidden_states"
,
"residual"
],
config
.
hidden_size
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
IntermediateTensors
|
None
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
|
IntermediateTensors
:
if
get_pp_group
().
is_first_rank
:
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
else
:
hidden_states
=
self
.
get_input_embeddings
(
input_ids
)
residual
=
None
else
:
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
residual
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
(
{
"hidden_states"
:
hidden_states
,
"residual"
:
residual
}
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
def
get_expert_mapping
(
self
)
->
list
[
tuple
[
str
,
str
,
int
,
str
]]:
return
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"w1"
,
ckpt_down_proj_name
=
"w2"
,
ckpt_up_proj_name
=
"w3"
,
num_experts
=
self
.
config
.
num_local_experts
,
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping
=
self
.
get_expert_mapping
()
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
spec_layer
=
get_spec_layer_idx_from_weight_name
(
self
.
config
,
name
)
if
spec_layer
is
not
None
:
continue
# skip spec decode layers for main model
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
# Skip non-stacked layers and experts (experts handled below).
if
weight_name
not
in
name
:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if
(
"mlp.experts."
in
name
)
and
name
not
in
params_dict
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
for
mapping
in
expert_params_mapping
:
param_name
,
weight_name
,
expert_id
,
shard_id
=
mapping
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
name
,
shard_id
=
shard_id
,
expert_id
=
expert_id
,
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
# Remapping the name of FP8 kv-scale.
name
=
maybe_remap_kv_scale_name
(
name
,
params_dict
)
if
name
is
None
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
class
MiniMaxM2ForCausalLM
(
nn
.
Module
,
SupportsPP
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
quant_config
=
quant_config
if
hasattr
(
vllm_config
.
model_config
,
"max_model_len"
):
self
.
config
.
max_model_len
=
vllm_config
.
model_config
.
max_model_len
self
.
model
=
MiniMaxM2Model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
)
)
if
get_pp_group
().
is_last_rank
:
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
None
)
else
:
self
.
lm_head
=
PPMissingLayer
()
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
**
kwargs
,
)
->
torch
.
Tensor
|
IntermediateTensors
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
|
None
:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
)
return
logits
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
loader
=
AutoWeightsLoader
(
self
)
return
loader
.
load_weights
(
weights
)
def
get_expert_mapping
(
self
)
->
list
[
tuple
[
str
,
str
,
int
,
str
]]:
return
self
.
model
.
get_expert_mapping
()
def
get_spec_layer_idx_from_weight_name
(
config
:
PretrainedConfig
,
weight_name
:
str
)
->
int
|
None
:
if
hasattr
(
config
,
"num_mtp_modules"
)
and
(
config
.
num_mtp_modules
>
0
):
layer_idx
=
config
.
num_hidden_layers
for
i
in
range
(
config
.
num_mtp_modules
):
if
weight_name
.
startswith
(
f
"model.layers.
{
layer_idx
+
i
}
."
):
return
layer_idx
+
i
return
None
vllm/model_executor/models/registry.py
View file @
f4776ec3
...
...
@@ -52,6 +52,7 @@ _TEXT_GENERATION_MODELS = {
"MiniMaxForCausalLM"
:
(
"minimax_text_01"
,
"MiniMaxText01ForCausalLM"
),
"MiniMaxText01ForCausalLM"
:
(
"minimax_text_01"
,
"MiniMaxText01ForCausalLM"
),
"MiniMaxM1ForCausalLM"
:
(
"minimax_text_01"
,
"MiniMaxText01ForCausalLM"
),
"MiniMaxM2ForCausalLM"
:
(
"minimax_m2"
,
"MiniMaxM2ForCausalLM"
),
# baichuan-7b, upper case 'C' in the class name
"BaiChuanForCausalLM"
:
(
"baichuan"
,
"BaiChuanForCausalLM"
),
# baichuan-13b, lower case 'c' in the class name
...
...
vllm/reasoning/__init__.py
View file @
f4776ec3
...
...
@@ -12,6 +12,7 @@ from .mistral_reasoning_parser import MistralReasoningParser
from
.qwen3_reasoning_parser
import
Qwen3ReasoningParser
from
.seedoss_reasoning_parser
import
SeedOSSReasoningParser
from
.step3_reasoning_parser
import
Step3ReasoningParser
from
.minimax_m2_reasoning_parser
import
MiniMaxM2ReasoningParser
__all__
=
[
"ReasoningParser"
,
...
...
@@ -27,4 +28,5 @@ __all__ = [
"Step3ReasoningParser"
,
"GptOssReasoningParser"
,
"SeedOSSReasoningParser"
,
"MiniMaxM2ReasoningParser"
,
]
vllm/reasoning/minimax_m2_reasoning_parser.py
0 → 100644
View file @
f4776ec3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Sequence
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
DeltaMessage
,
)
from
vllm.logger
import
init_logger
from
vllm.reasoning
import
ReasoningParser
from
vllm.reasoning.basic_parsers
import
BaseThinkingReasoningParser
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
logger
=
init_logger
(
__name__
)
class
MiniMaxM2ReasoningParser
(
BaseThinkingReasoningParser
):
"""
Reasoning parser for MiniMax M2 model.
"""
@
property
def
start_token
(
self
)
->
str
:
"""The token that starts reasoning content."""
return
"<think>"
@
property
def
end_token
(
self
)
->
str
:
"""The token that ends reasoning content."""
return
"</think>"
class
MiniMaxM2AppendThinkReasoningParser
(
ReasoningParser
):
"""
Reasoning parser for MiniMax M2 model.
"""
def
__init__
(
self
,
tokenizer
:
AnyTokenizer
,
*
args
,
**
kwargs
):
super
().
__init__
(
tokenizer
,
*
args
,
**
kwargs
)
self
.
end_token_id
=
self
.
vocab
.
get
(
"</think>"
)
def
is_reasoning_end
(
self
,
input_ids
:
list
[
int
])
->
bool
:
end_token_id
=
self
.
end_token_id
return
any
(
input_id
==
end_token_id
for
input_id
in
reversed
(
input_ids
))
def
extract_content_ids
(
self
,
input_ids
:
list
[
int
])
->
list
[
int
]:
return
input_ids
def
extract_reasoning_streaming
(
self
,
previous_text
:
str
,
current_text
:
str
,
delta_text
:
str
,
previous_token_ids
:
Sequence
[
int
],
current_token_ids
:
Sequence
[
int
],
delta_token_ids
:
Sequence
[
int
],
)
->
DeltaMessage
|
None
:
if
len
(
previous_token_ids
)
==
0
:
delta_text
=
"<think>"
+
delta_text
return
DeltaMessage
(
content
=
delta_text
)
def
extract_reasoning
(
self
,
model_output
:
str
,
request
:
ChatCompletionRequest
)
->
tuple
[
str
|
None
,
str
|
None
]:
return
None
,
"<think>"
+
model_output
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment