Unverified commit f6983f01, authored by liuchenbing2026 and committed by GitHub
Browse files

MiniMax-M2: add Eagle3 speculative decoding support (#37512)


Signed-off-by: liuchenbing <chenliumail@163.com>
Signed-off-by: liucb <liuchengbao_work@163.com>
Co-authored-by: liuchenbing <chenliumail@163.com>
parent 780ba374
...@@ -1246,6 +1246,12 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = { ...@@ -1246,6 +1246,12 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
use_original_num_layers=True, use_original_num_layers=True,
max_model_len=10240, max_model_len=10240,
), ),
"Eagle3MiniMaxM2ForCausalLM": _HfExamplesInfo(
"MiniMaxAI/MiniMax-M2",
trust_remote_code=True,
speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
tokenizer="MiniMaxAI/MiniMax-M2",
),
"EagleMistralLarge3ForCausalLM": _HfExamplesInfo( "EagleMistralLarge3ForCausalLM": _HfExamplesInfo(
"mistralai/Mistral-Large-3-675B-Instruct-2512", "mistralai/Mistral-Large-3-675B-Instruct-2512",
speculative_model="mistralai/Mistral-Large-3-675B-Instruct-2512-Eagle", speculative_model="mistralai/Mistral-Large-3-675B-Instruct-2512-Eagle",
......
...@@ -817,6 +817,7 @@ class SpeculativeConfig: ...@@ -817,6 +817,7 @@ class SpeculativeConfig:
"deepseek_v3", "deepseek_v3",
"kimi_k2", "kimi_k2",
"kimi_k25", "kimi_k25",
"minimax_m2",
] ]
if ( if (
self.method in ("eagle3", "extract_hidden_states", "dflash") self.method in ("eagle3", "extract_hidden_states", "dflash")
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
"""Inference-only MiniMaxM2 model.""" """Inference-only MiniMaxM2 model."""
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Any from typing import Any
import torch import torch
...@@ -59,7 +60,7 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -59,7 +60,7 @@ from vllm.model_executor.model_loader.weight_utils import (
) )
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import EagleModelMixin, SupportsEagle3, SupportsLoRA, SupportsPP
from .utils import ( from .utils import (
AutoWeightsLoader, AutoWeightsLoader,
PPMissingLayer, PPMissingLayer,
...@@ -313,7 +314,7 @@ class MiniMaxM2DecoderLayer(nn.Module): ...@@ -313,7 +314,7 @@ class MiniMaxM2DecoderLayer(nn.Module):
@support_torch_compile @support_torch_compile
class MiniMaxM2Model(nn.Module): class MiniMaxM2Model(nn.Module, EagleModelMixin):
fall_back_to_pt_during_load = False fall_back_to_pt_during_load = False
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
...@@ -366,7 +367,7 @@ class MiniMaxM2Model(nn.Module): ...@@ -366,7 +367,7 @@ class MiniMaxM2Model(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None, intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor | IntermediateTensors: ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
...@@ -378,14 +379,24 @@ class MiniMaxM2Model(nn.Module): ...@@ -378,14 +379,24 @@ class MiniMaxM2Model(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer : self.end_layer]: aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
for idx, layer in enumerate(
islice(self.layers, self.start_layer, self.end_layer)
):
hidden_states, residual = layer(positions, hidden_states, residual) hidden_states, residual = layer(positions, hidden_states, residual)
self._maybe_add_hidden_state(
aux_hidden_states, idx + 1, hidden_states, residual
)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors( return IntermediateTensors(
{"hidden_states": hidden_states, "residual": residual} {"hidden_states": hidden_states, "residual": residual}
) )
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
if len(aux_hidden_states) > 0:
return hidden_states, aux_hidden_states
return hidden_states return hidden_states
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
...@@ -496,7 +507,7 @@ class MiniMaxM2Model(nn.Module): ...@@ -496,7 +507,7 @@ class MiniMaxM2Model(nn.Module):
return loaded_params return loaded_params
class MiniMaxM2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): class MiniMaxM2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
......
...@@ -554,6 +554,7 @@ _SPECULATIVE_DECODING_MODELS = { ...@@ -554,6 +554,7 @@ _SPECULATIVE_DECODING_MODELS = {
"EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"), "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
"DFlashDraftModel": ("qwen3_dflash", "DFlashQwen3ForCausalLM"), "DFlashDraftModel": ("qwen3_dflash", "DFlashQwen3ForCausalLM"),
"Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"), "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
"Eagle3MiniMaxM2ForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
"LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"), "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
"Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"), "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
"Eagle3Qwen3vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"), "Eagle3Qwen3vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment