Unverified Commit a6ae3af1 authored by ryang, committed by GitHub

Support XiaomiMiMo inference with mtp (#6059)

parent 0b07c4a9
@@ -283,6 +283,60 @@
    "terminate_process(server_process)"
   ]
  },
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Multi Token Prediction\n",
"\n",
"We support [MTP(Multi-Token Prediction)](https://arxiv.org/pdf/2404.19737) in SGLang by using speculative decoding. We use Xiaomi/MiMo-7B-RL model as example here (deepseek mtp usage refer to [deepseek doc](../references/deepseek.md#multi-token-prediction))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"server_process, port = launch_server_cmd(\n",
" \"\"\"\n",
" python3 -m sglang.launch_server --model-path XiaomiMiMo/MiMo-7B-RL --host 0.0.0.0 --trust-remote-code \\\n",
" --speculative-algorithm EAGLE --speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2 \\\n",
" --mem-fraction 0.5\n",
"\"\"\"\n",
")\n",
"\n",
"wait_for_server(f\"http://localhost:{port}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import requests\n",
"\n",
"url = f\"http://localhost:{port}/v1/chat/completions\"\n",
"\n",
"data = {\n",
" \"model\": \"XiaomiMiMo/MiMo-7B-RL\",\n",
" \"messages\": [{\"role\": \"user\", \"content\": \"What is the capital of France?\"}],\n",
"}\n",
"\n",
"response = requests.post(url, json=data)\n",
"print_highlight(response.json())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"terminate_process(server_process)"
]
},
  {
   "cell_type": "markdown",
   "metadata": {},
...
@@ -73,6 +73,7 @@ class ModelConfig:
            model_override_args=self.model_override_args,
            **kwargs,
        )
        self.hf_text_config = get_hf_text_config(self.hf_config)
        self.attention_chunk_size = getattr(
            self.hf_text_config, "attention_chunk_size", None
@@ -97,6 +98,8 @@ class ModelConfig:
        ):
            self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN"
if is_draft_model and self.hf_config.architectures[0] == "MiMoForCausalLM":
self.hf_config.architectures[0] = "MiMoMTP"
        # Check model type
        self.is_generation = is_generation_model(
            self.hf_config.architectures, is_embedding
...
@@ -782,12 +782,15 @@ class ModelRunner:
             distributed=get_world_group().world_size > 1,
             cpu_group=get_world_group().cpu_group,
         )
-        if self.use_mla_backend:
-            num_layers = (
-                self.model_config.num_hidden_layers
-                if not self.is_draft_worker
-                else self.model_config.hf_config.num_nextn_predict_layers
-            )
+        if self.is_draft_worker:
+            num_layers = getattr(
+                self.model_config.hf_config,
+                "num_nextn_predict_layers",
+                self.num_effective_layers,
+            )
+        else:
+            num_layers = self.num_effective_layers
+        if self.use_mla_backend:
             # FIXME: pipeline parallelism is not compatible with mla backend
             assert self.pp_size == 1
             cell_size = (
@@ -799,7 +802,7 @@ class ModelRunner:
             cell_size = (
                 self.model_config.get_num_kv_heads(get_attention_tp_size())
                 * self.model_config.head_dim
-                * self.num_effective_layers
+                * num_layers
                 * 2
                 * torch._utils._element_size(self.kv_cache_dtype)
             )
...
# Adapted from https://github.com/vllm-project/vllm/pull/17433/files and deepseek_nextn.py
from functools import partial
from typing import Any, Dict, Iterable, Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
)
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.mimo import MiMoForCausalLM
from sglang.srt.models.qwen2 import (
Qwen2Attention,
Qwen2DecoderLayer,
Qwen2MLP,
Qwen2Model,
)
from sglang.srt.utils import add_prefix
class MiMoMultiTokenPredictorLayer(nn.Module):
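    """Single MTP draft layer: normalizes and fuses the target model's previous
    hidden state with the embedding of the current input token, then runs one
    Qwen2 decoder block followed by a final RMSNorm."""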
def __init__(
self,
config: PretrainedConfig,
prefix: str,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.token_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.hidden_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.input_proj = nn.Linear(
config.hidden_size * 2, config.hidden_size, bias=False
)
self.mtp_block = Qwen2DecoderLayer(
config=config, quant_config=quant_config, prefix=prefix
)
self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
if input_embeds is None:
hidden_states = self.embed_tokens(input_ids)
else:
hidden_states = input_embeds
        # mask the input embedding at position 0, which MTP does not need
hidden_states[positions == 0] = 0
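        # Fuse the (normalized) hidden state from the target model with the
        # (normalized) token embedding and project back to hidden_size.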
hidden_states = self.input_proj(
torch.cat(
(
self.hidden_layernorm(forward_batch.spec_info.hidden_states),
self.token_layernorm(hidden_states),
),
dim=-1,
)
)
hidden_states, residual = self.mtp_block(
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
residual=None,
)
hidden_states = residual + hidden_states
hidden_states = self.final_layernorm(hidden_states)
return hidden_states
class MiMoMTP(nn.Module):
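    """MTP draft model for MiMo, used by the EAGLE speculative-decoding worker.

    Wraps a single MiMoMultiTokenPredictorLayer plus an lm_head; the embedding and
    lm_head weights can be shared with the target model via set_embed_and_head().
    """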
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
nn.Module.__init__(self)
self.config = config
self.tp_size = get_tensor_model_parallel_world_size()
self.quant_config = quant_config
self.model = MiMoMultiTokenPredictorLayer(
config,
prefix,
quant_config,
)
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
)
self.logits_processor = LogitsProcessor(config)
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
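        # Run the single MTP layer, then turn its hidden states into draft logits.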
hidden_states = self.model(input_ids, positions, forward_batch)
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name or "projector" in name:
continue
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
if name.startswith("model.vision_tower") and name not in params_dict:
continue
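            # Remap checkpoint names (e.g. under "model.mtp_layers.<i>.") to this
            # module's flattened parameter layout.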
name = self.map_model_name_to_mtp_param_name(name)
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
if "mtp_block" not in name:
break
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
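                # Skip weights that belong only to the target model and have no
                # counterpart in the MTP head.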
if "mtp_block" not in name and (
"embed_tokens" not in name
and "lm_head" not in name
and "token_layernorm" not in name
and "hidden_layernorm" not in name
and "input_proj" not in name
and "final_layernorm" not in name
):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
def map_model_name_to_mtp_param_name(self, name: str) -> str:
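        """Map a checkpoint name under "model.mtp_layers.<i>." to this module's layout:
        the MTP-specific projections/norms map to "model.<name>", everything else maps
        to "model.mtp_block.<name>" (the wrapped Qwen2 decoder layer)."""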
import re
name_without_prefix = [
"token_layernorm",
"hidden_layernorm",
"input_proj",
"final_layernorm",
]
pattern = r"model.mtp_layers.(\d+)."
group = re.match(pattern, name)
if group is not None:
for sub_name in name_without_prefix:
if sub_name in name:
name = name.replace(group.group(), "model.")
return name
name = name.replace(group.group(), "model.mtp_block.")
return name
def get_embed_and_head(self):
return self.model.embed_tokens.weight, self.lm_head.weight
def set_embed_and_head(self, embed, head):
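        # Replace the draft model's embedding and lm_head weights with the target
        # model's tensors so the two models share them instead of keeping copies.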
del self.model.embed_tokens.weight
del self.lm_head.weight
self.model.embed_tokens.weight = embed
self.lm_head.weight = head
torch.cuda.empty_cache()
torch.cuda.synchronize()
EntryClass = MiMoMTP
import unittest
from types import SimpleNamespace
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestMiMoMTP(CustomTestCase):
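    """Launches MiMo-7B-RL with the MTP draft head (via EAGLE speculative decoding)
    and checks few-shot GSM8K accuracy against the launched server."""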
@classmethod
def setUpClass(cls):
cls.model = "XiaomiMiMo/MiMo-7B-RL"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--speculative-algorithm",
"EAGLE",
"--speculative-num-steps",
"1",
"--speculative-eagle-topk",
"1",
"--speculative-num-draft-tokens",
"2",
"--mem-fraction-static",
"0.5",
],
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval(args)
print(f"{metrics=}")
self.assertGreater(metrics["accuracy"], 0.7)
if __name__ == "__main__":
unittest.main()