Commit 44181448 authored by chenych's avatar chenych
Browse files

Add minimax-m2.1 tool call and resoning parase in v0.11.0

parent 8348926e
...@@ -19,13 +19,16 @@ from vllm.entrypoints.openai.protocol import ( ...@@ -19,13 +19,16 @@ from vllm.entrypoints.openai.protocol import (
) )
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParser,
ToolParserManager
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
logger = init_logger(__name__)
logger = init_logger(__name__)
@ToolParserManager.register_module("minimax-m2")
class MinimaxM2ToolParser(ToolParser): class MinimaxM2ToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: AnyTokenizer):
super().__init__(tokenizer) super().__init__(tokenizer)
...@@ -71,11 +74,43 @@ class MinimaxM2ToolParser(ToolParser): ...@@ -71,11 +74,43 @@ class MinimaxM2ToolParser(ToolParser):
self.tool_call_complete_regex = re.compile( self.tool_call_complete_regex = re.compile(
r"<minimax:tool_call>(.*?)</minimax:tool_call>", re.DOTALL r"<minimax:tool_call>(.*?)</minimax:tool_call>", re.DOTALL
) )
# Improved regex: capture only the name attribute value (quoted or unquoted)
# and ignore any additional attributes that may follow
self.invoke_complete_regex = re.compile( self.invoke_complete_regex = re.compile(
r"<invoke name=(.*?)</invoke>", re.DOTALL r"""
) <invoke\s+name= # Match tag start and name attribute key
( # Start Group 1: Name value
"[^"]+" # Double-quoted string
| # OR
'[^']+' # Single-quoted string
| # OR
[^\s>]+ # Unquoted value (no whitespace or >)
) # End Group 1
(?:\s+[^>]*)? # Optional: Extra attributes (ignored)
\s*> # Closing bracket of opening tag
(.*?) # Group 2: Content (non-greedy)
</invoke> # Closing tag
""",
re.VERBOSE | re.DOTALL,
)
# Improved regex for parameters: capture name attribute and content separately
# Handles cases where model may include description text in attributes
self.parameter_complete_regex = re.compile( self.parameter_complete_regex = re.compile(
r"<parameter name=(.*?)</parameter>", re.DOTALL r"""
<parameter\s+name= # Match tag start and name attribute key
( # Start Group 1: Name value
"[^"]+" # Double-quoted string
| # OR
'[^']+' # Single-quoted string
| # OR
[^\s>]+ # Unquoted value (no whitespace or >)
) # End Group 1
(?:\s+[^>]*)? # Optional: Extra attributes (ignored)
\s*> # Closing bracket of opening tag
(.*?) # Group 2: Content (non-greedy)
</parameter> # Closing tag
""",
re.VERBOSE | re.DOTALL,
) )
if not self.model_tokenizer: if not self.model_tokenizer:
...@@ -122,6 +157,8 @@ class MinimaxM2ToolParser(ToolParser): ...@@ -122,6 +157,8 @@ class MinimaxM2ToolParser(ToolParser):
self.streaming_request = None self.streaming_request = None
# Clear previous tool call history to avoid state pollution # Clear previous tool call history to avoid state pollution
self.prev_tool_call_arr.clear() self.prev_tool_call_arr.clear()
# Reset streamed args tracking
self.streamed_args_for_tool.clear()
def _extract_name(self, name_str: str) -> str: def _extract_name(self, name_str: str) -> str:
"""Extract name from quoted string.""" """Extract name from quoted string."""
...@@ -135,49 +172,231 @@ class MinimaxM2ToolParser(ToolParser): ...@@ -135,49 +172,231 @@ class MinimaxM2ToolParser(ToolParser):
return name_str[1:-1] return name_str[1:-1]
return name_str return name_str
def _parse_name_from_attributes(self, attr_section: str) -> str:
"""Helper to extract name from attribute section string.
Handles quoted and unquoted names, ignoring extra attributes."""
# Check for quoted name first
if attr_section.startswith('"'):
# Find closing quote
close_quote = attr_section.find('"', 1)
if close_quote != -1:
name_raw = attr_section[: close_quote + 1]
else:
name_raw = attr_section
elif attr_section.startswith("'"):
# Find closing single quote
close_quote = attr_section.find("'", 1)
if close_quote != -1:
name_raw = attr_section[: close_quote + 1]
else:
name_raw = attr_section
else:
# Unquoted name - take until first whitespace
space_idx = -1
for i, c in enumerate(attr_section):
if c.isspace():
space_idx = i
break
name_raw = attr_section[:space_idx] if space_idx != -1 else attr_section
return self._extract_name(name_raw)
def _convert_param_value(self, value: str, param_type: str) -> Any: def _convert_param_value(self, value: str, param_type: str) -> Any:
"""Convert parameter value to the correct type.""" """Convert parameter value to the correct type (legacy single-type version)."""
if value.lower() == "null": return self._convert_param_value_with_types(value, [param_type])
def _extract_types_from_schema(self, schema: Any) -> list[str]:
"""
Extract all possible types from a JSON schema definition.
Handles anyOf, oneOf, allOf, type arrays, and enum fields.
Args:
schema: The JSON schema definition for a parameter
Returns:
List of type strings (e.g., ["string", "integer", "null"])
"""
if schema is None:
return ["string"]
if not isinstance(schema, dict):
return ["string"]
types: set[str] = set()
# Handle direct "type" field
if "type" in schema:
type_value = schema["type"]
if isinstance(type_value, str):
types.add(type_value)
elif isinstance(type_value, list):
for t in type_value:
if isinstance(t, str):
types.add(t)
# Handle enum - infer types from enum values
if "enum" in schema and isinstance(schema["enum"], list) and schema["enum"]:
for value in schema["enum"]:
if value is None:
types.add("null")
elif isinstance(value, bool):
types.add("boolean")
elif isinstance(value, int):
types.add("integer")
elif isinstance(value, float):
types.add("number")
elif isinstance(value, str):
types.add("string")
elif isinstance(value, list):
types.add("array")
elif isinstance(value, dict):
types.add("object")
# Handle anyOf, oneOf, allOf - recursively extract types
for choice_field in ("anyOf", "oneOf", "allOf"):
if choice_field in schema and isinstance(schema[choice_field], list):
for choice in schema[choice_field]:
extracted = self._extract_types_from_schema(choice)
types.update(extracted)
# If no types found, default to string
if not types:
return ["string"]
return list(types)
def _convert_param_value_with_types(
self, value: str, param_types: list[str]
) -> Any:
"""
Convert parameter value to the correct type based on a list of possible types.
Tries each type in order until one succeeds.
Args:
value: The string value to convert
param_types: List of possible type strings
Returns:
The converted value
"""
# Check if the VALUE itself indicates null (not just if null is allowed)
if value.lower() in ("null", "none", "nil"):
return None return None
param_type = param_type.lower() # Normalize types
normalized_types = [t.lower() for t in param_types]
# Try each type in order of preference (most specific first, string as fallback)
# Priority: integer > number > boolean > object > array > string
type_priority = [
"integer",
"int",
"number",
"float",
"boolean",
"bool",
"object",
"array",
"string",
"str",
"text",
]
for param_type in type_priority:
if param_type not in normalized_types:
continue
if param_type in ["string", "str", "text"]: if param_type in ["string", "str", "text"]:
return value return value
elif param_type in ["integer", "int"]: elif param_type in ["integer", "int"]:
try: try:
return int(value) return int(value)
except (ValueError, TypeError): except (ValueError, TypeError):
return value continue
elif param_type in ["number", "float"]: elif param_type in ["number", "float"]:
try: try:
val = float(value) val = float(value)
return val if val != int(val) else int(val) return val if val != int(val) else int(val)
except (ValueError, TypeError): except (ValueError, TypeError):
return value continue
elif param_type in ["boolean", "bool"]: elif param_type in ["boolean", "bool"]:
return value.lower() in ["true", "1"] lower_val = value.lower().strip()
if lower_val in ["true", "1", "yes", "on"]:
return True
elif lower_val in ["false", "0", "no", "off"]:
return False
continue
elif param_type in ["object", "array"]: elif param_type in ["object", "array"]:
try: try:
return json.loads(value) return json.loads(value)
except json.JSONDecodeError: except json.JSONDecodeError:
return value continue
else:
# Try JSON parse first, fallback to string # Fallback: try JSON parse, then return as string
try: try:
return json.loads(value) return json.loads(value)
except json.JSONDecodeError: except json.JSONDecodeError:
return value return value
def _get_param_types_from_config(
self, param_name: str, param_config: dict
) -> list[str]:
"""
Get parameter types from parameter configuration.
Handles anyOf, oneOf, allOf, and direct type definitions.
Args:
param_name: The name of the parameter
param_config: The properties dict from the tool schema
Returns:
List of type strings
"""
if param_name not in param_config:
return ["string"]
param_schema = param_config[param_name]
if not isinstance(param_schema, dict):
return ["string"]
return self._extract_types_from_schema(param_schema)
def _parse_single_invoke( def _parse_single_invoke(
self, invoke_str: str, tools: list | None self, invoke_str: str, tools: list | None
) -> ToolCall | None: ) -> ToolCall | None:
"""Parse a single <invoke> block.""" """Parse a single <invoke> block.
# Extract function name
Args:
invoke_str: For legacy regex, this is the full content after
'<invoke name='. For new regex with groups, this is
a tuple of (name, content).
tools: List of available tools for type information.
Returns:
Parsed ToolCall or None if parsing fails.
"""
# Handle both old format (string) and new format (tuple from regex groups)
if isinstance(invoke_str, tuple):
# New regex format: (name_raw, content)
function_name = self._extract_name(invoke_str[0])
invoke_content = invoke_str[1] if len(invoke_str) > 1 else ""
else:
# Fallback for unexpected string input
# (should generally be tuple from regex)
# Try to extract similarly to tuple case
match = self.invoke_complete_regex.search(invoke_str)
if match:
function_name = self._extract_name(match.group(1))
invoke_content = match.group(2)
else:
# Basic fallback if regex doesn't match
name_match = re.search(r"^([^>]+)", invoke_str) name_match = re.search(r"^([^>]+)", invoke_str)
if not name_match: if not name_match:
return None return None
function_name = self._extract_name(name_match.group(1)) function_name = self._extract_name(name_match.group(1))
# Extract content after the closing '>'
content_match = re.search(r"^[^>]+>(.*)", invoke_str, re.DOTALL)
invoke_content = content_match.group(1) if content_match else ""
# Get parameter configuration # Get parameter configuration
param_config = {} param_config = {}
...@@ -193,29 +412,31 @@ class MinimaxM2ToolParser(ToolParser): ...@@ -193,29 +412,31 @@ class MinimaxM2ToolParser(ToolParser):
param_config = params["properties"] param_config = params["properties"]
break break
# Extract parameters # Extract parameters using the improved regex
param_dict = {} param_dict = {}
for match in self.parameter_complete_regex.findall(invoke_str): for match in self.parameter_complete_regex.findall(invoke_content):
param_match = re.search(r"^([^>]+)>(.*)", match, re.DOTALL) # match is now a tuple: (param_name_raw, param_value)
if param_match: if isinstance(match, tuple) and len(match) >= 2:
param_name = self._extract_name(match[0])
param_value = match[1].strip()
else:
# Fallback for unexpected format
param_match = re.search(r"^([^>]+)>(.*)", str(match), re.DOTALL)
if not param_match:
continue
param_name = self._extract_name(param_match.group(1)) param_name = self._extract_name(param_match.group(1))
param_value = param_match.group(2).strip() param_value = param_match.group(2).strip()
if param_value.startswith("\n"): if param_value.startswith("\n"):
param_value = param_value[1:] param_value = param_value[1:]
if param_value.endswith("\n"): if param_value.endswith("\n"):
param_value = param_value[:-1] param_value = param_value[:-1]
# Get parameter type # Get parameter types (supports anyOf/oneOf/allOf)
param_type = "string" param_type = self._get_param_types_from_config(param_name, param_config)
if (
param_name in param_config
and isinstance(param_config[param_name], dict)
and "type" in param_config[param_name]
):
param_type = param_config[param_name]["type"]
# Convert value # Convert value
param_dict[param_name] = self._convert_param_value( param_dict[param_name] = self._convert_param_value_with_types(
param_value, param_type param_value, param_type
) )
...@@ -403,13 +624,16 @@ class MinimaxM2ToolParser(ToolParser): ...@@ -403,13 +624,16 @@ class MinimaxM2ToolParser(ToolParser):
func_start = tool_text.find(self.invoke_start_prefix) + len( func_start = tool_text.find(self.invoke_start_prefix) + len(
self.invoke_start_prefix self.invoke_start_prefix
) )
# Find the end quote for the function name # Find the end of the opening tag
func_end = tool_text.find(">", func_start) func_end = tool_text.find(">", func_start)
if func_end != -1: if func_end != -1:
# Found complete function name # Found complete function name
function_name_raw = tool_text[func_start:func_end] # Handle cases where model may add extra attributes after name
self.current_function_name = self._extract_name(function_name_raw) attr_section = tool_text[func_start:func_end]
self.current_function_name = self._parse_name_from_attributes(
attr_section
)
self.current_tool_id = self._generate_tool_call_id() self.current_tool_id = self._generate_tool_call_id()
self.header_sent = True self.header_sent = True
self.in_function = True self.in_function = True
...@@ -421,9 +645,12 @@ class MinimaxM2ToolParser(ToolParser): ...@@ -421,9 +645,12 @@ class MinimaxM2ToolParser(ToolParser):
self.prev_tool_call_arr.append( self.prev_tool_call_arr.append(
{ {
"name": self.current_function_name, "name": self.current_function_name,
"arguments": "{}", # Placeholder, will be updated later "arguments": {}, # Placeholder, will be updated later
} }
) )
# Initialize streamed_args_for_tool for this tool call
if len(self.streamed_args_for_tool) <= self.current_tool_index:
self.streamed_args_for_tool.append("")
# Send header with function info # Send header with function info
return DeltaMessage( return DeltaMessage(
...@@ -445,6 +672,9 @@ class MinimaxM2ToolParser(ToolParser): ...@@ -445,6 +672,9 @@ class MinimaxM2ToolParser(ToolParser):
# Send opening brace if not sent yet # Send opening brace if not sent yet
if self.in_function and not self.json_started: if self.in_function and not self.json_started:
self.json_started = True self.json_started = True
# Update streamed_args_for_tool for opening brace
if self.current_tool_index < len(self.streamed_args_for_tool):
self.streamed_args_for_tool[self.current_tool_index] += "{"
return DeltaMessage( return DeltaMessage(
tool_calls=[ tool_calls=[
DeltaToolCall( DeltaToolCall(
...@@ -493,7 +723,7 @@ class MinimaxM2ToolParser(ToolParser): ...@@ -493,7 +723,7 @@ class MinimaxM2ToolParser(ToolParser):
args = parsed_tool.function.arguments args = parsed_tool.function.arguments
self.prev_tool_call_arr[self.current_tool_index][ self.prev_tool_call_arr[self.current_tool_index][
"arguments" "arguments"
] = args ] = json.loads(args)
except Exception: except Exception:
pass # Ignore parsing errors during streaming pass # Ignore parsing errors during streaming
...@@ -505,7 +735,9 @@ class MinimaxM2ToolParser(ToolParser): ...@@ -505,7 +735,9 @@ class MinimaxM2ToolParser(ToolParser):
) )
] ]
) )
# Update streamed_args_for_tool for closing brace
if self.current_tool_index < len(self.streamed_args_for_tool):
self.streamed_args_for_tool[self.current_tool_index] += "}"
# Reset state for next tool # Reset state for next tool
self.json_closed = True self.json_closed = True
self.in_function = False self.in_function = False
...@@ -542,9 +774,14 @@ class MinimaxM2ToolParser(ToolParser): ...@@ -542,9 +774,14 @@ class MinimaxM2ToolParser(ToolParser):
if ">" in remaining: if ">" in remaining:
# We have the complete parameter name # We have the complete parameter name
# Handle cases where model may add extra attributes after name
# e.g., <parameter name="cmd" description="(e.g. ls)">
name_end = remaining.find(">") name_end = remaining.find(">")
param_name_raw = remaining[:name_end] attr_section = remaining[:name_end]
self.current_param_name = self._extract_name(param_name_raw)
self.current_param_name = self._parse_name_from_attributes(
attr_section
)
# Find the parameter value # Find the parameter value
value_start = param_start + name_end + 1 value_start = param_start + name_end + 1
...@@ -583,7 +820,7 @@ class MinimaxM2ToolParser(ToolParser): ...@@ -583,7 +820,7 @@ class MinimaxM2ToolParser(ToolParser):
# Store raw value for later processing # Store raw value for later processing
self.accumulated_params[self.current_param_name] = param_value self.accumulated_params[self.current_param_name] = param_value
# Get parameter configuration for type conversion # Get parameter configuration with anyOf support
param_config = {} param_config = {}
if self.streaming_request and self.streaming_request.tools: if self.streaming_request and self.streaming_request.tools:
for tool in self.streaming_request.tools: for tool in self.streaming_request.tools:
...@@ -600,17 +837,12 @@ class MinimaxM2ToolParser(ToolParser): ...@@ -600,17 +837,12 @@ class MinimaxM2ToolParser(ToolParser):
param_config = params["properties"] param_config = params["properties"]
break break
# Get parameter type # Get parameter types (supports anyOf/oneOf/allOf)
param_type = "string" param_type = self._get_param_types_from_config(
if ( self.current_param_name, param_config
self.current_param_name in param_config )
and isinstance(param_config[self.current_param_name], dict)
and "type" in param_config[self.current_param_name]
):
param_type = param_config[self.current_param_name]["type"]
# Convert param value to appropriate type converted_value = self._convert_param_value_with_types(
converted_value = self._convert_param_value(
param_value, param_type param_value, param_type
) )
...@@ -630,7 +862,11 @@ class MinimaxM2ToolParser(ToolParser): ...@@ -630,7 +862,11 @@ class MinimaxM2ToolParser(ToolParser):
) )
self.param_count += 1 self.param_count += 1
# Update streamed_args_for_tool for this tool call
if self.current_tool_index < len(self.streamed_args_for_tool):
self.streamed_args_for_tool[self.current_tool_index] += (
json_fragment
)
return DeltaMessage( return DeltaMessage(
tool_calls=[ tool_calls=[
DeltaToolCall( DeltaToolCall(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment