Unverified Commit 7ebc28f5 authored by 赵晨阳, committed by GitHub

[WIP] support MiniMax M2 model (#12129)


Signed-off-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
Signed-off-by: xuebi <xuebi@minimaxi.com>
Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com>
Co-authored-by: Roger Young <42564206+rogeryoungh@users.noreply.github.com>
Co-authored-by: xuebi <xuebi@minimaxi.com>
parent b89111d6
...@@ -35,6 +35,7 @@ in the GitHub search bar.
| **MiniCPM** (v3, 4B) | `openbmb/MiniCPM3-4B` | OpenBMB’s series of compact LLMs for edge devices; MiniCPM 3 (4B) achieves GPT-3.5-level results in text tasks. |
| **OLMo** (2, 3) | `allenai/OLMo-2-1124-7B-Instruct` | Allen AI’s series of Open Language Models designed to enable the science of language models. |
| **OLMoE** (Open MoE) | `allenai/OLMoE-1B-7B-0924` | Allen AI’s open Mixture-of-Experts model (7B total, 1B active parameters) delivering state-of-the-art results with sparse expert activation. |
| **MiniMax-M2** | `minimax/MiniMax-M2` | MiniMax’s state-of-the-art Mixture-of-Experts LLM for coding and agentic workflows. |
| **StableLM** (3B, 7B) | `stabilityai/stablelm-tuned-alpha-7b` | StabilityAI’s early open-source LLM (3B & 7B) for general text generation; a demonstration model with basic instruction-following ability. |
| **Command-R** (Cohere) | `CohereForAI/c4ai-command-r-v01` | Cohere’s open conversational LLM (Command series) optimized for long context, retrieval-augmented generation, and tool use. |
| **DBRX** (Databricks) | `databricks/dbrx-instruct` | Databricks’ 132B-parameter MoE model (36B active) trained on 12T tokens; competes with GPT-3.5 quality as a fully open foundation model. |
......
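For orientation, a hedged serving sketch for the new entry: it assumes SGLang's offline Engine forwards the usual server arguments (including the parser names registered elsewhere in this PR); the tensor-parallel size is a placeholder.

import sglang as sgl

# Hedged sketch only; adjust the model path and tp_size to your deployment.
llm = sgl.Engine(
    model_path="minimax/MiniMax-M2",   # repo id from the table above
    tp_size=8,                         # placeholder, not a recommendation
    tool_call_parser="minimax-m2",     # registered in function_call_parser below
    reasoning_parser="minimax",        # registered in reasoning_parser below
)
print(llm.generate("Write a short sorting function.", {"max_new_tokens": 128}))
llm.shutdown()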
...@@ -16,6 +16,7 @@ from sglang.srt.function_call.glm4_moe_detector import Glm4MoeDetector
from sglang.srt.function_call.gpt_oss_detector import GptOssDetector
from sglang.srt.function_call.kimik2_detector import KimiK2Detector
from sglang.srt.function_call.llama32_detector import Llama32Detector
from sglang.srt.function_call.minimax_m2 import MinimaxM2Detector
from sglang.srt.function_call.mistral_detector import MistralDetector
from sglang.srt.function_call.pythonic_detector import PythonicDetector
from sglang.srt.function_call.qwen3_coder_detector import Qwen3CoderDetector
...@@ -49,6 +50,7 @@ class FunctionCallParser:
"qwen25": Qwen25Detector,
"qwen3_coder": Qwen3CoderDetector,
"step3": Step3Detector,
"minimax-m2": MinimaxM2Detector,
}
def __init__(self, tools: List[Tool], tool_call_parser: str):
......
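A brief, hedged illustration of the new registry key; only the constructor signature shown in this hunk is relied on, and the import path is inferred from the module names in this diff.

from typing import List

from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.function_call_parser import FunctionCallParser  # path assumed

tools: List[Tool] = []  # populate with the request's OpenAI-style tool schemas
parser = FunctionCallParser(tools=tools, tool_call_parser="minimax-m2")
# "minimax-m2" resolves to MinimaxM2Detector via the mapping extended above;
# the --tool-call-parser server flag selects the same key per request.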
import ast
import html
import json
import logging
import re
from typing import Any, Dict, List, Tuple
from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
ToolCallItem,
_GetInfoFunc,
)
from sglang.srt.function_call.ebnf_composer import EBNFComposer
logger = logging.getLogger(__name__)
def _safe_val(raw: str) -> Any:
raw = html.unescape(raw.strip())
try:
return json.loads(raw)
except Exception:
try:
return ast.literal_eval(raw)
except Exception:
return raw
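# Behavior sketch for _safe_val (example inputs are illustrative, not from tests):
#   _safe_val('{"city": "Paris"}')  -> {'city': 'Paris'}   (valid JSON)
#   _safe_val("['a', 'b']")         -> ['a', 'b']          (JSON fails, ast.literal_eval succeeds)
#   _safe_val("&quot;42&quot;")     -> '42'                (HTML entities unescaped before parsing)
#   _safe_val("plain text")         -> 'plain text'        (both parsers fail, raw string returned)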
class MinimaxM2Detector(BaseFormatDetector):
"""
Detector for MiniMax M2 models.
Assumes function call format:
<minimax:tool_call>
<invoke name="func1">
<parameter name="param1">value1</parameter>
<parameter name="param2">value2</parameter>
</invoke>
</minimax:tool_call>
"""
def __init__(self):
super().__init__()
self.tool_call_start_token: str = "<minimax:tool_call>"
self.tool_call_end_token: str = "</minimax:tool_call>"
self.tool_call_prefix: str = '<invoke name="'
self.tool_call_function_end_token: str = "</invoke>"
self.tool_call_regex = re.compile(
r"<minimax:tool_call>(.*?)</minimax:tool_call>|<minimax:tool_call>(.*?)$",
re.DOTALL,
)
self.tool_call_function_regex = re.compile(
r"<invoke name=\"(.*?)</invoke>|<invoke name=\"(.*)$", re.DOTALL
)
self.tool_call_parameter_regex = re.compile(
r"<parameter name=\"(.*?)</parameter>|<parameter name=\"(.*?)$", re.DOTALL
)
self._buf: str = ""
# Streaming state variables
self._current_function_name: str = ""
self._current_parameters: Dict[str, Any] = {}
self._streamed_parameters: Dict[str, str] = {}  # Track what parameter content we've streamed
self._in_tool_call: bool = False
self._function_name_sent: bool = False
def has_tool_call(self, text: str) -> bool:
return self.tool_call_start_token in text
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
normal, calls = self._extract(text, tools)
return StreamingParseResult(normal_text=normal, calls=calls)
def parse_streaming_increment(
self, new_text: str, tools: List[Tool]
) -> StreamingParseResult:
self._buf += new_text
normal = ""
calls: List[ToolCallItem] = []
# Build tool indices for validation
if not hasattr(self, "_tool_indices"):
self._tool_indices = self._get_tool_indices(tools)
while True:
# If we're not in a tool call and don't see a start token, return normal text
if not self._in_tool_call and self.tool_call_start_token not in self._buf:
normal += self._buf
self._buf = ""
break
# Look for tool call start
if not self._in_tool_call:
s = self._buf.find(self.tool_call_start_token)
if s == -1:
normal += self._buf
self._buf = ""
break
normal += self._buf[:s]
self._buf = self._buf[s:]
self._in_tool_call = True
self._function_name_sent = False
self._current_function_name = ""
self._current_parameters = {}
self._streamed_parameters = {}
# Remove the start token
self._buf = self._buf[len(self.tool_call_start_token) :]
continue
# We're in a tool call, try to parse function name if not sent yet
if not self._function_name_sent:
# Look for a complete function declaration: <invoke name="function_name">
function_match = re.search(r"<invoke name=\"([^>]+)\">", self._buf)
if function_match:
function_name = function_match.group(1).strip()
# Validate function name
if function_name in self._tool_indices:
self._current_function_name = function_name
self._function_name_sent = True
# Initialize tool call tracking
if self.current_tool_id == -1:
self.current_tool_id = 0
# Ensure tracking arrays are large enough
while len(self.prev_tool_call_arr) <= self.current_tool_id:
self.prev_tool_call_arr.append({})
while len(self.streamed_args_for_tool) <= self.current_tool_id:
self.streamed_args_for_tool.append("")
# Store tool call info
self.prev_tool_call_arr[self.current_tool_id] = {
"name": function_name,
"arguments": {},
}
# Send tool name with empty parameters
calls.append(
ToolCallItem(
tool_index=self.current_tool_id,
name=function_name,
parameters="",
)
)
# Remove the processed function declaration
self._buf = self._buf[function_match.end() :]
continue
else:
# Invalid function name, reset state
logger.warning(f"Invalid function name: {function_name}")
self._reset_streaming_state()
normal += self._buf
self._buf = ""
break
else:
# Function name not complete yet, wait for more text
break
# Parse parameters incrementally
if self._function_name_sent:
# Process parameters and get any calls to emit
parameter_calls = self._parse_and_stream_parameters(self._buf)
calls.extend(parameter_calls)
# Check if tool call is complete
if self.tool_call_function_end_token in self._buf:
end_pos = self._buf.find(self.tool_call_function_end_token)
# Add closing brace to complete the JSON object
current_streamed = self.streamed_args_for_tool[self.current_tool_id]
if current_streamed:
# Count opening and closing braces to check if JSON is complete
open_braces = current_streamed.count("{")
close_braces = current_streamed.count("}")
if open_braces > close_braces:
calls.append(
ToolCallItem(
tool_index=self.current_tool_id,
name=None,
parameters="}",
)
)
self.streamed_args_for_tool[self.current_tool_id] = (
current_streamed + "}"
)
# Complete the tool call
self._buf = self._buf[
end_pos + len(self.tool_call_function_end_token) :
]
self._reset_streaming_state(True)
self.current_tool_id += 1
continue
else:
# Tool call not complete yet, wait for more text
break
return StreamingParseResult(normal_text=normal, calls=calls)
def _parse_and_stream_parameters(self, text_to_parse: str) -> List[ToolCallItem]:
"""
Parse complete parameter blocks from text and return any tool call items to emit.
This method:
1. Finds all complete <parameter> blocks
2. Parses them into a dictionary
3. Compares with current parameters and generates diff if needed
4. Updates internal state
Args:
text_to_parse: The text to search for parameter blocks
Returns:
List of ToolCallItem objects to emit (may be empty)
"""
calls: List[ToolCallItem] = []
# Find all complete parameter patterns
param_matches = list(
re.finditer(
r"<parameter name=\"([^>]+)\">(.*?)</parameter>",
text_to_parse,
re.DOTALL,
)
)
# Build new parameters dictionary
new_params = {}
for match in param_matches:
param_name = match.group(1).strip()
param_value = match.group(2)
new_params[param_name] = _safe_val(param_value)
# Calculate parameter diff to stream with proper incremental JSON building
if new_params != self._current_parameters:
previous_args_json = self.streamed_args_for_tool[self.current_tool_id]
# Build incremental JSON properly
if not self._current_parameters:
# First parameter(s) - start JSON object but don't close it yet
items = []
for key, value in new_params.items():
items.append(
f"{json.dumps(key, ensure_ascii=False)}: {json.dumps(value, ensure_ascii=False)}"
)
json_fragment = "{" + ", ".join(items)
calls.append(
ToolCallItem(
tool_index=self.current_tool_id,
name=None,
parameters=json_fragment,
)
)
self.streamed_args_for_tool[self.current_tool_id] = json_fragment
else:
# Additional parameters - add them incrementally
new_keys = set(new_params.keys()) - set(self._current_parameters.keys())
if new_keys:
# Build the continuation part (no closing brace yet)
continuation_parts = []
for key in new_keys:
value = new_params[key]
continuation_parts.append(
f"{json.dumps(key, ensure_ascii=False)}: {json.dumps(value, ensure_ascii=False)}"
)
json_fragment = ", " + ", ".join(continuation_parts)
calls.append(
ToolCallItem(
tool_index=self.current_tool_id,
name=None,
parameters=json_fragment,
)
)
self.streamed_args_for_tool[self.current_tool_id] = (
previous_args_json + json_fragment
)
# Update current state
self._current_parameters = new_params
self.prev_tool_call_arr[self.current_tool_id]["arguments"] = new_params
return calls
def _reset_streaming_state(self, still_in_tool_call: bool = False):
"""Reset streaming state for the next tool call"""
self._in_tool_call = still_in_tool_call
self._function_name_sent = False
self._current_function_name = ""
self._current_parameters = {}
self._streamed_parameters = {}
self.current_tool_name_sent = False
def _extract(self, text: str, tools: List[Tool]) -> Tuple[str, List[ToolCallItem]]:
normal_parts: List[str] = []
calls: List[ToolCallItem] = []
cursor = 0
while True:
s = text.find(self.tool_call_start_token, cursor)
if s == -1:
normal_parts.append(text[cursor:])
break
normal_parts.append(text[cursor:s])
e = text.find(self.tool_call_end_token, s)
if e == -1:
normal_parts.append(text[s:])
break
block = text[s : e + len(self.tool_call_end_token)]
cursor = e + len(self.tool_call_end_token)
calls.extend(self._parse_block(block, tools))
return "".join(normal_parts), calls
def _parse_block(self, block: str, tools: List[Tool]) -> List[ToolCallItem]:
res: List[ToolCallItem] = []
for m in self.tool_call_function_regex.findall(block):
txt = m[0] if m[0] else m[1]
if '">' not in txt:
continue
idx = txt.index('">')
fname = txt[:idx].strip()
body = txt[idx + 2 :]
params: Dict[str, Any] = {}
for pm in self.tool_call_parameter_regex.findall(body):
ptxt = pm[0] if pm[0] else pm[1]
if '">' not in ptxt:
continue
pidx = ptxt.index('">')
pname = ptxt[:pidx].strip()
pval = ptxt[pidx + 2 :].lstrip("\n").rstrip("\n")
params[pname] = _safe_val(pval)
raw = {"name": fname, "arguments": params}
try:
# TODO: fix idx in function call, the index for a function
# call will always be -1 in parse_base_json
res.extend(self.parse_base_json(raw, tools))
except Exception:
logger.warning("invalid tool call for %s dropped", fname)
return res
def supports_structural_tag(self) -> bool:
return False
def structure_info(self) -> _GetInfoFunc:
raise NotImplementedError
def build_ebnf(self, tools: List[Tool]):
return EBNFComposer.build_ebnf(
tools,
individual_call_start_token=self.tool_call_start_token.replace("\n", "\\n"),
individual_call_end_token=self.tool_call_end_token.replace("\n", "\\n"),
tool_call_separator="\\n",
function_format="xml",
call_rule_fmt='"<invoke name=\\"{name}\\">\\n" {arguments_rule} "\\n</invoke>"',
key_value_rule_fmt='"<parameter name=\\"{key}\\">\\n" {valrule} "\\n</parameter>"',
key_value_separator='"\\n"',
)
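A hedged end-to-end sketch of the non-streaming path above; the tool schema and model output are invented for illustration, and `Function` is assumed to live alongside `Tool` in the OpenAI protocol module.

from sglang.srt.entrypoints.openai.protocol import Function, Tool  # Function assumed
from sglang.srt.function_call.minimax_m2 import MinimaxM2Detector

tools = [
    Tool(
        type="function",
        function=Function(
            name="get_weather",
            description="Look up current weather for a city.",
            parameters={"type": "object", "properties": {"city": {"type": "string"}}},
        ),
    )
]

output = (
    "Let me check that.\n"
    '<minimax:tool_call>\n<invoke name="get_weather">\n'
    '<parameter name="city">Paris</parameter>\n'
    "</invoke>\n</minimax:tool_call>"
)

result = MinimaxM2Detector().detect_and_parse(output, tools)
# result.normal_text -> "Let me check that.\n" (text outside the tool-call block)
# result.calls       -> one ToolCallItem named "get_weather" whose parameters
#                       should decode to {"city": "Paris"}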
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Adapted from DeepSeek and Mixtral implementation
"""Inference-only MiniMax M2 model compatible with HuggingFace weights."""
import logging
from typing import Iterable, Optional, Set, Tuple, Union
import torch
from torch import nn
from transformers import PretrainedConfig
from sglang.srt.distributed import (
get_moe_expert_parallel_world_size,
get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.communicator import (
LayerCommunicator,
LayerScatterModes,
ScatterMode,
)
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.moe.utils import get_moe_a2a_backend
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.utils import PPMissingLayer
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from sglang.srt.server_args import get_global_server_args
from sglang.srt.two_batch_overlap import model_forward_maybe_tbo
from sglang.srt.utils import (
BumpAllocator,
add_prefix,
get_compiler_backend,
is_non_idle_and_non_empty,
make_layers,
)
logger = logging.getLogger(__name__)
class MiniMaxM2RMSNormTP(nn.Module):
"""RMSNorm with Tensor Parallel support for QK normalization."""
def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
super().__init__()
self.tp_world = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
# Weight parameter is sharded across TP ranks
self.weight = nn.Parameter(torch.ones(int(hidden_size / self.tp_world)))
self.weight.weight_loader = self.weight_loader
self.variance_epsilon = eps
@staticmethod
def weight_loader(
param: nn.Parameter,
loaded_weight: torch.Tensor,
) -> None:
"""Custom weight loader that handles TP sharding."""
tp_world = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
shard_size = loaded_weight.shape[0] // tp_world
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
param.data.copy_(loaded_weight[shard])
@torch.compile(dynamic=True, backend=get_compiler_backend())
def forward(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""Forward pass with TP-aware variance computation."""
assert residual is None, "RMSNormTP does not support residual connection."
orig_dtype = x.dtype
x = x.to(torch.float32)
# Compute the local mean of squares; the all-reduce below turns this into the variance over the full (unsharded) hidden dimension
variance = x.pow(2).mean(dim=-1, keepdim=True, dtype=torch.float32)
if self.tp_world > 1:
# All-reduce variance across TP ranks to get global variance
variance = tensor_model_parallel_all_reduce(variance) / self.tp_world
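# Correctness note: every rank holds an equal shard of hidden_size / tp_world
# elements, so mean_full(x^2) = (1 / tp_world) * sum_over_ranks(mean_local(x^2)),
# which is exactly what the all-reduce followed by the division computes.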
# Normalize and apply local weight shard
x = x * torch.rsqrt(variance + self.variance_epsilon)
x = x.to(orig_dtype) * self.weight
return x
class MiniMaxM2MLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "mlp",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=add_prefix("gate_up_proj", prefix),
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=add_prefix("down_proj", prefix),
)
self.act_fn = SiluAndMul()
return
def forward(self, x: torch.Tensor) -> torch.Tensor:
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class MiniMaxM2MoE(nn.Module):
"""MiniMax MoE implementation using DeepEP for Expert Parallel support."""
def __init__(
self,
config: PretrainedConfig,
layer_id: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
if self.tp_size > config.num_local_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {config.num_local_experts}."
)
self.use_routing_bias = getattr(config, "use_routing_bias", False)
if self.use_routing_bias:
self.e_score_correction_bias = nn.Parameter(
torch.empty(config.num_local_experts, dtype=torch.float32)
)
self.e_score_correction_bias.weight_loader = (
MiniMaxM2MoE.ebias_weight_loader
)
else:
self.e_score_correction_bias = None
self.experts = get_moe_impl_class(quant_config)(
num_experts=config.num_local_experts
+ get_global_server_args().ep_num_redundant_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("experts", prefix),
)
self.topk = TopK(
top_k=config.num_experts_per_tok,
renormalize=True,
scoring_func=config.scoring_func,
use_grouped_topk=True, # TODO: Use "grouped top-k" flag only for hardcoded sigmoid scoring
num_expert_group=1,
topk_group=1,
correction_bias=self.e_score_correction_bias,
routed_scaling_factor=1.0,
)
self.gate = ReplicatedLinear(
config.hidden_size,
config.num_local_experts,
bias=False,
params_dtype=torch.float32,
quant_config=None,
prefix=add_prefix("gate", prefix),
)
self.layer_id = layer_id
if get_moe_a2a_backend().is_deepep():
self.ep_size = get_moe_expert_parallel_world_size()
self.top_k = config.num_experts_per_tok
@staticmethod
def ebias_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor) -> None:
assert param.size() == loaded_weight.size()
param.data.copy_(loaded_weight.to(torch.float32))
def forward(
self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
) -> torch.Tensor:
if get_moe_a2a_backend().is_deepep():
return self.forward_deepep(hidden_states, forward_batch)
else:
return self.forward_normal(hidden_states)
def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states.to(torch.float32))
topk_output = self.topk(hidden_states, router_logits)
final_hidden_states = self.experts(hidden_states, topk_output)
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_dim)
def forward_deepep(
self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
) -> torch.Tensor:
if hidden_states.shape[0] > 0:
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states.to(torch.float32))
topk_weights, topk_idx, _ = self.topk(
hidden_states,
router_logits,
num_token_non_padded=forward_batch.num_token_non_padded,
expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
layer_id=self.layer_id,
),
)
else:
topk_weights, topk_idx, _ = self.topk.empty_topk_output(
hidden_states.shape[0], self.top_k
)
final_hidden_states = self.experts(
hidden_states=hidden_states,
topk_idx=topk_idx,
topk_weights=topk_weights,
forward_batch=forward_batch,
)
return final_hidden_states
# TBO Operations for MiniMax MoE
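# Rough per-sub-batch order under two-batch overlap (interleaving between the
# two sub-batches is handled by two_batch_overlap):
#   op_gate -> op_select_experts -> op_dispatch_a -> op_dispatch_b
#   -> op_experts -> op_combine_a -> op_combine_b -> op_output
# The *_a / *_b split lets the asynchronous DeepEP dispatch/combine of one
# sub-batch overlap with computation of the other.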
def op_gate(self, state):
"""Gate operation for TBO - compute router logits"""
if is_non_idle_and_non_empty(
state.forward_batch.forward_mode, state.hidden_states_mlp_input
): # router_logits: (num_tokens, num_experts)
state.router_logits, _ = self.gate(state.hidden_states_mlp_input)
else:
state.router_logits = None
def op_select_experts(self, state):
"""Expert selection operation for TBO"""
router_logits = state.pop("router_logits")
hidden_states = state.hidden_states_mlp_input
if router_logits is not None:
with get_global_expert_distribution_recorder().with_current_layer(
self.layer_id
):
state.topk_weights_local, state.topk_idx_local, _ = self.topk(
hidden_states=hidden_states,
router_logits=router_logits,
num_token_non_padded=state.forward_batch.num_token_non_padded,
expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
layer_id=self.layer_id,
),
)
else:
state.topk_idx_local = torch.full(
(0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
)
state.topk_weights_local = torch.empty(
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
)
def op_dispatch_a(self, state):
"""Dispatch A operation for TBO - start async dispatch"""
if self.ep_size > 1:
self.experts.deepep_dispatcher.dispatch_a(
hidden_states=state.pop("hidden_states_mlp_input"),
topk_idx=state.pop("topk_idx_local"),
topk_weights=state.pop("topk_weights_local"),
forward_batch=state.forward_batch,
tbo_subbatch_index=state.get("tbo_subbatch_index"),
)
def op_dispatch_b(self, state):
"""Dispatch B operation for TBO - complete async dispatch"""
if self.ep_size > 1:
with get_global_expert_distribution_recorder().with_current_layer(
self.layer_id
):
state.dispatch_output = self.experts.deepep_dispatcher.dispatch_b(
tbo_subbatch_index=state.get("tbo_subbatch_index"),
)
def op_experts(self, state):
"""Expert computation for TBO"""
state.hidden_states_experts_output = self.experts.moe_impl(
dispatch_output=state.dispatch_output,
)
def op_combine_a(self, state):
"""Combine A operation for TBO - start async combine"""
if self.ep_size > 1:
self.experts.deepep_dispatcher.combine_a(
hidden_states=state.pop("hidden_states_experts_output"),
topk_idx=state.dispatch_output.topk_idx,
topk_weights=state.dispatch_output.topk_weights,
forward_batch=state.forward_batch,
tbo_subbatch_index=state.get("tbo_subbatch_index"),
)
state.pop("dispatch_output")
def op_combine_b(self, state):
"""Combine B operation for TBO - complete async combine"""
if self.ep_size > 1:
state.hidden_states_after_combine = (
self.experts.deepep_dispatcher.combine_b(
tbo_subbatch_index=state.get("tbo_subbatch_index"),
)
)
def op_output(self, state):
"""Output operation for TBO - final MLP output"""
final_hidden_states = state.pop("hidden_states_after_combine")
# MiniMax doesn't have shared experts like DeepSeek, so no need to add them
state.hidden_states_mlp_output = final_hidden_states
class MiniMaxM2Attention(nn.Module):
"""MiniMax Attention implementation with QK normalization and partial RoPE."""
def __init__(
self,
config: PretrainedConfig,
layer_id: int = 0,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
tp_size = get_tensor_model_parallel_world_size()
# Get dimensions from config
self.total_num_heads = config.num_attention_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = config.num_key_value_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than or equal to TP size, so we
# partition the KV heads across the tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
# Use head_dim from config if available, otherwise calculate
self.head_dim = getattr(
config, "head_dim", self.hidden_size // self.total_num_heads
)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
# RoPE settings - support partial RoPE
self.rope_theta = getattr(config, "rope_theta", 10000)
self.max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.rotary_dim = getattr(
config, "rotary_dim", self.head_dim
) # MiniMax uses rotary_dim=64
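# Partial RoPE: with rotary_dim < head_dim, get_rope rotates only the first
# rotary_dim dimensions of each head and passes the remaining dimensions
# through unchanged.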
# QK Normalization settings
self.use_qk_norm = getattr(config, "use_qk_norm", False)
self.qk_norm_type = getattr(config, "qk_norm_type", "per_layer")
self.qkv_proj = QKVParallelLinear(
self.hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
prefix=add_prefix("qkv_proj", prefix),
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
self.hidden_size,
bias=False,
reduce_results=False,
quant_config=quant_config,
prefix=add_prefix("o_proj", prefix),
)
# Setup RoPE with partial rotary dimension
rope_scaling = getattr(config, "rope_scaling", None)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.rotary_dim, # Use partial rotary dimension
max_position=self.max_position_embeddings,
base=self.rope_theta,
rope_scaling=rope_scaling,
)
# QK Normalization layers
if self.use_qk_norm:
if self.qk_norm_type == "per_layer":
# Use RMSNormTP for proper tensor parallel support
# Use total dimensions (before TP sharding) for correct normalization
self.q_norm = MiniMaxM2RMSNormTP(
self.total_num_heads * self.head_dim, eps=config.rms_norm_eps
)
self.k_norm = MiniMaxM2RMSNormTP(
self.total_num_kv_heads * self.head_dim, eps=config.rms_norm_eps
)
else:
raise ValueError(f"Unsupported qk_norm_type: {self.qk_norm_type}")
self.attn = RadixAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)
def forward_prepare(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
):
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.use_qk_norm:
q = self.q_norm(q.contiguous())
k = self.k_norm(k.contiguous())
else:
q, k = q.contiguous(), k.contiguous()
q, k = self.rotary_emb(positions, q, k)
inner_state = q, k, v, forward_batch
return None, forward_batch, inner_state
def forward_core(self, intermediate_state):
_, _, inner_state = intermediate_state
attn_output = self.attn(*inner_state)
output, _ = self.o_proj(attn_output)
return output
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
s = self.forward_prepare(
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
)
return self.forward_core(s)
def op_prepare(self, state):
state.attn_intermediate_state = self.forward_prepare(
positions=state.positions,
hidden_states=state.pop("hidden_states_after_comm_pre_attn"),
forward_batch=state.forward_batch,
)
def op_core(self, state):
state.hidden_states_after_attn = self.forward_core(
state.pop("attn_intermediate_state")
)
class MiniMaxM2DecoderLayer(nn.Module):
"""MiniMax Decoder Layer implementation with MoE support."""
def __init__(
self,
config: PretrainedConfig,
layer_id: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
self.layer_id = layer_id
# TBO support: All MiniMax layers are sparse (MoE)
self.is_layer_sparse = True
self.self_attn = MiniMaxM2Attention(
config=config,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("self_attn", prefix),
)
self.block_sparse_moe = MiniMaxM2MoE(
config=config,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("mlp", prefix),
)
self.input_layernorm = RMSNorm(
config.hidden_size, eps=getattr(config, "rms_norm_eps", 1e-6)
)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=getattr(config, "rms_norm_eps", 1e-6)
)
is_previous_layer_sparse = True
self.layer_scatter_modes = LayerScatterModes.init_new(
layer_id=layer_id,
num_layers=config.num_hidden_layers,
is_layer_sparse=self.is_layer_sparse,
is_previous_layer_sparse=is_previous_layer_sparse,
)
self.layer_communicator = LayerCommunicator(
layer_scatter_modes=self.layer_scatter_modes,
input_layernorm=self.input_layernorm,
post_attention_layernorm=self.post_attention_layernorm,
allow_reduce_scatter=True,
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
# Self Attention
hidden_states, residual = self.layer_communicator.prepare_attn(
hidden_states, residual, forward_batch
)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
)
# Fully Connected (MLP or MoE)
hidden_states, residual = self.layer_communicator.prepare_mlp(
hidden_states, residual, forward_batch
)
hidden_states = self.block_sparse_moe(hidden_states, forward_batch)
hidden_states, residual = self.layer_communicator.postprocess_layer(
hidden_states, residual, forward_batch
)
return hidden_states, residual
# TBO Operations for MiniMax Decoder Layer
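# Rough layer-level order under two-batch overlap, mirroring forward() above:
#   op_comm_prepare_attn -> self_attn.op_prepare -> self_attn.op_core
#   -> op_comm_prepare_mlp -> op_mlp (or the finer-grained MoE ops on
#   MiniMaxM2MoE) -> op_comm_postprocess_layer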
def op_comm_prepare_attn(
self,
state,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
zero_allocator: BumpAllocator,
tbo_subbatch_index: Optional[int] = None,
):
"""Communication prepare for attention - TBO operation"""
state.hidden_states_after_comm_pre_attn, state.residual_after_input_ln = (
self.layer_communicator.prepare_attn(hidden_states, residual, forward_batch)
)
state.update(
dict(
forward_batch=forward_batch,
positions=positions,
zero_allocator=zero_allocator,
tbo_subbatch_index=tbo_subbatch_index,
)
)
def op_comm_prepare_mlp(self, state):
"""Communication prepare for MLP - TBO operation"""
state.hidden_states_mlp_input, state.residual_after_comm_pre_mlp = (
self.layer_communicator.prepare_mlp(
state.pop("hidden_states_after_attn"),
state.pop("residual_after_input_ln"),
state.forward_batch,
)
)
def op_mlp(self, state):
hidden_states = state.pop("hidden_states_mlp_input")
state.hidden_states_mlp_output = self.block_sparse_moe(
hidden_states, state.forward_batch
)
def op_comm_postprocess_layer(self, state):
"""Communication postprocess for layer - TBO operation"""
hidden_states, residual = self.layer_communicator.postprocess_layer(
state.pop("hidden_states_mlp_output"),
state.pop("residual_after_comm_pre_mlp"),
state.forward_batch,
)
output = dict(
positions=state.positions,
hidden_states=hidden_states,
residual=residual,
forward_batch=state.forward_batch,
zero_allocator=state.zero_allocator,
tbo_subbatch_index=state.tbo_subbatch_index,
)
return output
class MiniMaxM2Model(nn.Module):
"""MiniMax Model implementation."""
fall_back_to_pt_during_load = False
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.padding_idx = getattr(config, "pad_token_id", 0)
self.vocab_size = config.vocab_size
self.pp_group = get_pp_group()
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
def layer_fn(idx, prefix: str) -> nn.Module:
return MiniMaxM2DecoderLayer(
config=config,
layer_id=idx,
quant_config=quant_config,
prefix=prefix,
)
self.layers, self.start_layer, self.end_layer = make_layers(
config.num_hidden_layers,
layer_fn,
pp_rank=self.pp_group.rank_in_group,
pp_size=self.pp_group.world_size,
prefix=add_prefix("layers", prefix),
)
if self.pp_group.is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer(return_tuple=True)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
) -> Union[torch.Tensor, PPProxyTensors]:
if self.pp_group.is_first_rank:
if input_embeds is None:
hidden_states = self.get_input_embeddings(input_ids)
else:
hidden_states = input_embeds
residual = None
else:
assert pp_proxy_tensors is not None
hidden_states = pp_proxy_tensors["hidden_states"]
residual = pp_proxy_tensors["residual"]
if forward_batch.can_run_tbo:
hidden_states, residual = model_forward_maybe_tbo(
layers=self.layers,
enable_tbo=True,
input_data_scatter_mode=ScatterMode.model_input_output(),
positions=positions,
forward_batch=forward_batch,
hidden_states=hidden_states,
residual=residual,
)
else:
for i in range(self.start_layer, self.end_layer):
with get_global_expert_distribution_recorder().with_current_layer(i):
layer = self.layers[i]
hidden_states, residual = layer(
positions=positions,
forward_batch=forward_batch,
hidden_states=hidden_states,
residual=residual,
)
if not self.pp_group.is_last_rank:
return PPProxyTensors(
{"hidden_states": hidden_states, "residual": residual}
)
if residual is not None:
hidden_states, _ = self.norm(hidden_states, residual)
else:
hidden_states = self.norm(hidden_states)
return hidden_states
class MiniMaxM2ForCausalLM(nn.Module):
"""MiniMax M2 model for causal language modeling."""
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = MiniMaxM2Model(
config, quant_config, prefix=add_prefix("model", prefix)
)
if get_pp_group().is_last_rank:
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=None,
prefix=add_prefix("lm_head", prefix),
)
else:
self.lm_head = PPMissingLayer()
self.logits_processor = LogitsProcessor(config)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
# _print_tensor_info(input_ids, "input_ids")
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
"""Load model weights with proper mapping for MiniMax architecture."""
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="w1",
ckpt_down_proj_name="w2",
ckpt_up_proj_name="w3",
num_experts=self.config.num_local_experts,
)
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
if spec_layer is not None:
continue # skip spec decode layers for main model
for param_name, weight_name, shard_id in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if ("mlp.experts." in name) and name not in params_dict:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(
param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id,
)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
@classmethod
def get_model_config_for_expert_location(cls, config):
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
return ModelConfigForExpertLocation(
num_layers=config.num_hidden_layers,
num_logical_experts=config.num_local_experts,
num_groups=None,
)
def get_spec_layer_idx_from_weight_name(
config: PretrainedConfig, weight_name: str
) -> Optional[int]:
if hasattr(config, "num_mtp_modules") and (config.num_mtp_modules > 0):
layer_idx = config.num_hidden_layers
for i in range(config.num_mtp_modules):
if weight_name.startswith(f"model.layers.{layer_idx + i}."):
return layer_idx + i
return None
# Entry class for model registration
EntryClass = MiniMaxM2ForCausalLM
...@@ -249,6 +249,31 @@ class GptOssDetector(BaseReasoningFormatDetector):
)
class MiniMaxAppendThinkDetector(BaseReasoningFormatDetector):
"""
Append `<think>` token to the beginning of the text.
"""
def __init__(self, stream_reasoning: bool = True, force_reasoning: bool = False):
# scheduler.py needs `reasoning_parser.detector.think_end_token`
super().__init__(
"<think>",
"</think>",
force_reasoning=force_reasoning,
stream_reasoning=stream_reasoning,
)
self.is_first_chunk = False
def parse_streaming_increment(self, new_text: str) -> StreamingParseResult:
if not self.is_first_chunk:
self.is_first_chunk = True
new_text = self.think_start_token + new_text
return StreamingParseResult(normal_text=new_text)
def detect_and_parse(self, text: str) -> StreamingParseResult:
return StreamingParseResult(normal_text=self.think_start_token + text)
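# Illustrative behavior (inputs invented): the opening <think> tag is prepended
# to the first streamed chunk and all text is otherwise passed through as-is.
#   d = MiniMaxAppendThinkDetector()
#   d.parse_streaming_increment("First ")    -> normal_text "<think>First "
#   d.parse_streaming_increment("thought")   -> normal_text "thought"
#   d.detect_and_parse("plan</think>answer") -> normal_text "<think>plan</think>answer"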
class ReasoningParser:
"""
Parser that handles both streaming and non-streaming scenarios for extracting
...@@ -268,6 +293,8 @@ class ReasoningParser:
"kimi": KimiDetector, "kimi": KimiDetector,
"qwen3": Qwen3Detector, "qwen3": Qwen3Detector,
"qwen3-thinking": Qwen3Detector, "qwen3-thinking": Qwen3Detector,
"minimax": Qwen3Detector,
"minimax-append-think": MiniMaxAppendThinkDetector,
"step3": DeepSeekR1Detector, "step3": DeepSeekR1Detector,
} }
...@@ -285,7 +312,7 @@ class ReasoningParser:
raise ValueError(f"Unsupported model type: {model_type}")
# Special cases where we override force_reasoning
if model_type.lower() in {"qwen3-thinking", "gpt-oss", "minimax"}:
force_reasoning = True
# Only pass force_reasoning if explicitly set, let detectors use their defaults
......
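For completeness, a hedged sketch of exercising the new reasoning-parser keys; the import path and the non-streaming method name are assumptions based on the existing detectors, not part of this diff.

from sglang.srt.parser.reasoning_parser import ReasoningParser  # path assumed

# "minimax" reuses Qwen3Detector with force_reasoning overridden to True (see
# above); "minimax-append-think" instead prepends the missing <think> tag.
parser = ReasoningParser(model_type="minimax")
reasoning_text, normal_text = parser.parse_non_stream(  # method name assumed
    "The user wants a haiku.</think>Here is one about autumn leaves."
)
# Expected (approximately): reasoning_text holds the text before </think>,
# normal_text holds the text after it.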