# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json
import uuid
from collections.abc import Sequence
from typing import Any

import regex as re

from vllm.entrypoints.openai.chat_completion.protocol import (
    ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import (
    DeltaFunctionCall,
    DeltaMessage,
    DeltaToolCall,
    ExtractedToolCallInformation,
    FunctionCall,
    ToolCall,
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
    Tool,
    ToolParser,
)

logger = init_logger(__name__)


class MinimaxM2ToolParser(ToolParser):
    """Tool-call parser for MiniMax M2 model output.

    The parser recognises tool calls written in the following XML-like form::

        <minimax:tool_call>
        <invoke name="function_name">
        <parameter name="param_name">value</parameter>
        </invoke>
        </minimax:tool_call>

    Parameter values arrive as raw text and are converted to JSON types using
    the JSON schema of the matching tool (when one is supplied).
    """

    def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
        """Set up sentinel tokens, regexes and streaming state.

        Args:
            tokenizer: Tokenizer whose vocabulary must contain the tool-call
                start and end sentinel tokens.
            tools: Optional tool definitions used to look up parameter types.

        Raises:
            ValueError: If no tokenizer is available on the parser.
            RuntimeError: If the start/end sentinel tokens are not in the
                tokenizer vocabulary.
        """
        super().__init__(tokenizer, tools)

        self.prev_tool_call_arr: list[dict] = []

        # Sentinel tokens
        self.tool_call_start_token: str = "<minimax:tool_call>"
        self.tool_call_end_token: str = "</minimax:tool_call>"

        # Streaming state
        self.is_tool_call_started: bool = False
        self.current_tool_index: int = 0

        # Regex patterns for complete parsing.
        # The invoke/parameter patterns deliberately capture everything after
        # ``name=`` (quoted name, ``>`` and body); the name is split off later
        # with ``^([^>]+)``.
        self.tool_call_complete_regex = re.compile(
            r"<minimax:tool_call>(.*?)</minimax:tool_call>", re.DOTALL
        )
        self.invoke_complete_regex = re.compile(
            r"<invoke name=(.*?)</invoke>", re.DOTALL
        )
        self.parameter_complete_regex = re.compile(
            r"<parameter name=(.*?)</parameter>", re.DOTALL
        )

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction."
            )

        self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
        self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)

        if self.tool_call_start_token_id is None or self.tool_call_end_token_id is None:
            raise RuntimeError(
                "MiniMax M2 Tool parser could not locate tool call start/end "
                "tokens in the tokenizer!"
            )

        logger.debug(
            "vLLM Successfully import tool parser %s !", self.__class__.__name__
        )

    def _generate_tool_call_id(self) -> str:
        """Generate a unique tool call ID of the form ``call_<24 hex chars>``."""
        return f"call_{uuid.uuid4().hex[:24]}"

    def _extract_name(self, name_str: str) -> str:
        """Strip surrounding whitespace and one pair of matching quotes.

        Args:
            name_str: Raw attribute text, e.g. ``'"get_weather"'``.

        Returns:
            The name without enclosing single or double quotes.
        """
        name_str = name_str.strip()
        if (name_str.startswith('"') and name_str.endswith('"')) or (
            name_str.startswith("'") and name_str.endswith("'")
        ):
            return name_str[1:-1]
        return name_str

    def _extract_types_from_schema(self, schema: Any) -> list[str]:
        """
        Extract all possible types from a JSON schema definition.
        Handles anyOf, oneOf, allOf, type arrays, and enum fields.

        Args:
            schema: The JSON schema definition for a parameter

        Returns:
            Sorted list of type strings (e.g., ["integer", "null", "string"]).
            ``["string"]`` is returned when the schema is missing, is not a
            dict, or yields no type information.
        """
        if schema is None:
            return ["string"]

        if not isinstance(schema, dict):
            return ["string"]

        types: set[str] = set()

        # Handle direct "type" field
        if "type" in schema:
            type_value = schema["type"]
            if isinstance(type_value, str):
                types.add(type_value)
            elif isinstance(type_value, list):
                for t in type_value:
                    if isinstance(t, str):
                        types.add(t)

        # Handle enum - infer types from enum values
        if "enum" in schema and isinstance(schema["enum"], list) and schema["enum"]:
            for value in schema["enum"]:
                if value is None:
                    types.add("null")
                # bool must be tested before int: bool is a subclass of int.
                elif isinstance(value, bool):
                    types.add("boolean")
                elif isinstance(value, int):
                    types.add("integer")
                elif isinstance(value, float):
                    types.add("number")
                elif isinstance(value, str):
                    types.add("string")
                elif isinstance(value, list):
                    types.add("array")
                elif isinstance(value, dict):
                    types.add("object")

        # Handle anyOf, oneOf, allOf - recursively extract types
        for choice_field in ("anyOf", "oneOf", "allOf"):
            if choice_field in schema and isinstance(schema[choice_field], list):
                for choice in schema[choice_field]:
                    extracted = self._extract_types_from_schema(choice)
                    types.update(extracted)

        # If no types found, default to string
        if not types:
            return ["string"]

        # Sorted so the result is deterministic; consumers only test membership.
        return sorted(types)

    def _convert_param_value_with_types(
        self, value: str, param_types: list[str]
    ) -> Any:
        """
        Convert parameter value to the correct type based on a list of
        possible types.

        Candidate types are tried in a fixed priority order (not in the order
        given by ``param_types``) until one conversion succeeds.

        Args:
            value: The string value to convert
            param_types: List of possible type strings

        Returns:
            The converted value. ``None`` for the literals null/none/nil
            (case-insensitive); the original string if nothing else applies.
        """
        # Check if the VALUE itself indicates null (not just if null is allowed)
        if value.lower() in ("null", "none", "nil"):
            return None

        # Normalize types
        normalized_types = [t.lower() for t in param_types]

        # Try each type in order of preference (most specific first, string as
        # fallback).
        # Priority: integer > number > boolean > object > array > string
        type_priority = [
            "integer",
            "int",
            "number",
            "float",
            "boolean",
            "bool",
            "object",
            "array",
            "string",
            "str",
            "text",
        ]

        for param_type in type_priority:
            if param_type not in normalized_types:
                continue

            if param_type in ["string", "str", "text"]:
                return value
            elif param_type in ["integer", "int"]:
                try:
                    return int(value)
                except (ValueError, TypeError):
                    continue
            elif param_type in ["number", "float"]:
                # Exact integers first: going through float() would silently
                # round integers whose magnitude exceeds 2**53.
                try:
                    return int(value)
                except (ValueError, TypeError):
                    pass
                try:
                    val = float(value)
                    # int(nan) raises ValueError and int(inf) raises
                    # OverflowError; both mean "not usable as a number here",
                    # so fall through to the next candidate type.
                    return val if val != int(val) else int(val)
                except (ValueError, TypeError, OverflowError):
                    continue
            elif param_type in ["boolean", "bool"]:
                lower_val = value.lower().strip()
                if lower_val in ["true", "1", "yes", "on"]:
                    return True
                elif lower_val in ["false", "0", "no", "off"]:
                    return False
                continue
            elif param_type in ["object", "array"]:
                try:
                    return json.loads(value)
                except json.JSONDecodeError:
                    continue

        # Fallback: try JSON parse, then return as string
        try:
            return json.loads(value)
        except json.JSONDecodeError:
            return value

    def _get_param_types_from_config(
        self, param_name: str, param_config: dict
    ) -> list[str]:
        """
        Get parameter types from parameter configuration.
        Handles anyOf, oneOf, allOf, and direct type definitions.

        Args:
            param_name: The name of the parameter
            param_config: The properties dict from the tool schema

        Returns:
            List of type strings; ``["string"]`` when the parameter is unknown
            or its schema is not a dict.
        """
        if param_name not in param_config:
            return ["string"]

        param_schema = param_config[param_name]
        if not isinstance(param_schema, dict):
            return ["string"]

        return self._extract_types_from_schema(param_schema)

    def _parse_single_invoke(
        self, invoke_str: str, tools: list | None
    ) -> ToolCall | None:
        """Parse a single <invoke> block.

        Args:
            invoke_str: Text captured by ``invoke_complete_regex``, i.e. what
                follows ``<invoke name=`` up to (not including) ``</invoke>``.
            tools: Tool definitions used to look up parameter schemas.

        Returns:
            A ``ToolCall`` whose arguments are a JSON-encoded object, or
            ``None`` if no function name can be extracted.
        """
        # Extract function name
        name_match = re.search(r"^([^>]+)", invoke_str)
        if not name_match:
            return None

        function_name = self._extract_name(name_match.group(1))

        # Get parameter configuration
        param_config = {}
        if tools:
            for tool in tools:
                if (
                    hasattr(tool, "function")
                    and tool.function.name == function_name
                    and hasattr(tool.function, "parameters")
                ):
                    params = tool.function.parameters
                    if isinstance(params, dict) and "properties" in params:
                        param_config = params["properties"]
                    break

        # Extract parameters
        param_dict = {}
        for match in self.parameter_complete_regex.findall(invoke_str):
            # ``match`` looks like: "name">value
            param_match = re.search(r"^([^>]+)>(.*)", match, re.DOTALL)
            if param_match:
                param_name = self._extract_name(param_match.group(1))
                param_value = param_match.group(2).strip()

                # Get parameter types (supports anyOf/oneOf/allOf)
                param_type = self._get_param_types_from_config(param_name, param_config)

                # Convert value
                param_dict[param_name] = self._convert_param_value_with_types(
                    param_value, param_type
                )

        return ToolCall(
            type="function",
            function=FunctionCall(
                name=function_name,
                arguments=json.dumps(param_dict, ensure_ascii=False),
            ),
        )

    def _extract_delta_tool_calls(
        self,
        current_text: str,
        request: ChatCompletionRequest | None,
    ) -> list[DeltaToolCall]:
        """Extract DeltaToolCalls from newly completed <invoke> blocks.

        Tracks progress via ``current_tool_index`` so each block is
        extracted exactly once across successive streaming calls.

        Args:
            current_text: All text generated so far; it is rescanned in full
                on every call.
            request: Unused here; accepted for signature symmetry.

        Returns:
            One ``DeltaToolCall`` per newly completed, parseable block. Blocks
            that fail to parse are skipped but still consume an index.
        """
        complete_invokes = self.invoke_complete_regex.findall(current_text)
        delta_tool_calls: list[DeltaToolCall] = []

        while len(complete_invokes) > self.current_tool_index:
            invoke_str = complete_invokes[self.current_tool_index]
            tool_call = self._parse_single_invoke(
                invoke_str,
                self.tools,
            )
            if not tool_call:
                self.current_tool_index += 1
                continue

            args_json = tool_call.function.arguments
            idx = self.current_tool_index
            self.current_tool_index += 1

            self.prev_tool_call_arr.append(
                {
                    "name": tool_call.function.name,
                    "arguments": json.loads(args_json),
                }
            )
            self.streamed_args_for_tool.append(args_json)
            delta_tool_calls.append(
                DeltaToolCall(
                    index=idx,
                    id=self._generate_tool_call_id(),
                    function=DeltaFunctionCall(
                        name=tool_call.function.name,
                        arguments=args_json,
                    ),
                    type="function",
                )
            )

        return delta_tool_calls

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract tool calls from complete model output (non-streaming).

        Args:
            model_output: The full generated text.
            request: The originating chat completion request (unused here).

        Returns:
            Extraction result. When no tool call is found, or parsing raises,
            ``tools_called`` is False and ``content`` is the unmodified
            output. Otherwise ``content`` is the text preceding the first
            start token, or ``None`` if there is none.
        """
        # Quick check
        if self.tool_call_start_token not in model_output:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        try:
            tool_calls = []

            # Find all complete tool_call blocks
            for tool_call_match in self.tool_call_complete_regex.findall(model_output):
                # Find all invokes within this tool_call
                for invoke_match in self.invoke_complete_regex.findall(tool_call_match):
                    tool_call = self._parse_single_invoke(invoke_match, self.tools)
                    if tool_call:
                        tool_calls.append(tool_call)

            if not tool_calls:
                return ExtractedToolCallInformation(
                    tools_called=False, tool_calls=[], content=model_output
                )

            # Update prev_tool_call_arr
            # NOTE(review): "arguments" is stored here as a JSON string, while
            # the streaming path stores a parsed dict - confirm which form
            # downstream consumers expect before unifying.
            self.prev_tool_call_arr.clear()
            for tool_call in tool_calls:
                self.prev_tool_call_arr.append(
                    {
                        "name": tool_call.function.name,
                        "arguments": tool_call.function.arguments,
                    }
                )

            # Extract content before first tool call
            first_tool_idx = model_output.find(self.tool_call_start_token)
            content = model_output[:first_tool_idx] if first_tool_idx > 0 else None

            return ExtractedToolCallInformation(
                tools_called=True, tool_calls=tool_calls, content=content
            )

        except Exception:
            # Best-effort boundary: any parsing failure degrades to plain
            # content rather than failing the request.
            logger.exception("Error extracting tool calls")
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],  # pylint: disable=unused-argument
        current_token_ids: Sequence[int],  # pylint: disable=unused-argument
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Extract tool calls from streaming model output.

        Uses a buffer-until-complete-invoke strategy: tokens are buffered
        until a complete ``<invoke>...</invoke>`` block is available, then
        parsed and emitted in one shot.

        Args:
            previous_text: Text generated before this delta.
            current_text: Text generated so far, including this delta.
            delta_text: Newly decoded text.
            previous_token_ids: Unused.
            current_token_ids: Unused.
            delta_token_ids: Token ids of this delta.
            request: The originating chat completion request.

        Returns:
            A ``DeltaMessage`` carrying content and/or tool calls, an empty
            content message for a text-less non-end-token delta after tool
            calls were emitted, or ``None`` when there is nothing to send.
        """
        start_in_text = self.tool_call_start_token in delta_text
        start_in_ids = self.tool_call_start_token_id in delta_token_ids
        tool_call_starting = start_in_text or start_in_ids

        # Reset state on new request (parser is reused) or new tool-call block.
        # NOTE(review): when a second tool-call block starts in the same
        # response, current_tool_index goes back to 0 while
        # _extract_delta_tool_calls rescans all of current_text, so invokes
        # from the first block would be emitted again - confirm whether
        # multiple blocks per response can occur.
        if not previous_text or tool_call_starting:
            self.current_tool_index = 0
            self.prev_tool_call_arr.clear()
            self.streamed_args_for_tool.clear()
            self.is_tool_call_started = tool_call_starting

        # Pass through content before any tool call.
        if not self.is_tool_call_started:
            return DeltaMessage(content=delta_text) if delta_text else None

        # Capture content before the start token.
        content_before = None
        if start_in_text:
            before = delta_text[: delta_text.index(self.tool_call_start_token)]
            content_before = before or None

        # Extract newly completed <invoke> blocks as DeltaToolCalls.
        delta_tool_calls = self._extract_delta_tool_calls(current_text, request)

        if delta_tool_calls or content_before:
            return DeltaMessage(
                content=content_before,
                tool_calls=delta_tool_calls,
            )

        # EOS and </minimax:tool_call> both arrive as special tokens with
        # no decoded text. Return non-None for EOS so the serving framework
        # reaches the finish-reason handling path instead of skipping.
        if (
            not delta_text
            and delta_token_ids
            and self.prev_tool_call_arr
            and self.tool_call_end_token_id not in delta_token_ids
        ):
            return DeltaMessage(content="")

        return None