Unverified Commit 5d93a950 authored by Yuhong Guo, committed by GitHub

[BugFix] Fix combination of MTP and `--n-share-experts-fusion` with R1 (#5707)

parent c998d04b
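In short: the shared-experts fusion gate in `DeepseekV2ForCausalLM` is factored out into a `determine_n_share_experts_fusion(architecture)` helper, the MTP/NextN draft model (`DeepseekV3ForCausalLMNextN`) calls it with its own architecture name, and the NextN `load_weights` clones the shared-expert tensors into the extra fused expert slots. A minimal, self-contained sketch of the parameterized gate (class and variable names here are illustrative, not the sglang implementation):

```python
class FusionGateSketch:
    """Toy stand-in for the config/gating logic; not the sglang classes."""

    def __init__(self, architectures, n_routed_experts, requested_fusion):
        self.architectures = architectures
        self.n_routed_experts = n_routed_experts
        self.n_share_experts_fusion = requested_fusion

    def determine_n_share_experts_fusion(self, architecture="DeepseekV3ForCausalLM"):
        # Only DeepSeek V3/R1-shaped checkpoints (256 routed experts) whose
        # architecture matches the caller's expectation keep the optimization.
        if self.n_share_experts_fusion > 0 and (
            self.architectures[0] != architecture or self.n_routed_experts != 256
        ):
            self.n_share_experts_fusion = 0


# The target model checks against "DeepseekV3ForCausalLM"; the MTP/NextN draft
# model now passes "DeepseekV3ForCausalLMNextN", so the check no longer zeroes
# its replica count out.
draft = FusionGateSketch(["DeepseekV3ForCausalLMNextN"], 256, 8)
draft.determine_n_share_experts_fusion("DeepseekV3ForCausalLMNextN")
print(draft.n_share_experts_fusion)  # 8 -> fusion stays enabled for the draft model
```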
@@ -13,12 +13,14 @@
 # ==============================================================================
 """Inference-only DeepSeek NextN Speculative Decoding."""
+import logging
 from typing import Iterable, Optional, Tuple

 import torch
 from torch import nn
 from transformers import PretrainedConfig

+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import ReplicatedLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
@@ -51,6 +53,9 @@ else:
     from vllm._custom_ops import awq_dequantize

+logger = logging.getLogger(__name__)
+
+
 class DeepseekModelNextN(nn.Module):
     def __init__(
         self,
@@ -134,7 +139,9 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
     ) -> None:
         nn.Module.__init__(self)
         self.config = config
+        self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
+        self.determine_n_share_experts_fusion("DeepseekV3ForCausalLMNextN")

         self.model = DeepseekModelNextN(
             config, quant_config, prefix=add_prefix("model", prefix)
@@ -182,6 +189,48 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
+        if self.n_share_experts_fusion > 0:
+            logger.info(
+                f"Cloning {self.n_share_experts_fusion} "
+                "replicas of the shared expert into MoE for DeepseekV3ForCausalLMNextN"
+            )
+            weights_list = list(weights)
+            weights_dict = dict(weights_list)
+            if self.quant_config is None or self.quant_config.get_name() == "w8a8_int8":
+                suffix_list = [
+                    "down_proj.weight",
+                    "down_proj.weight_scale",
+                    "gate_proj.weight",
+                    "gate_proj.weight_scale",
+                    "up_proj.weight",
+                    "up_proj.weight_scale",
+                ]
+            else:
+                suffix_list = [
+                    "down_proj.weight",
+                    "down_proj.weight_scale_inv",
+                    "gate_proj.weight",
+                    "gate_proj.weight_scale_inv",
+                    "up_proj.weight",
+                    "up_proj.weight_scale_inv",
+                ]
+            names_to_remove = []
+            for num_repeat in range(self.n_share_experts_fusion):
+                for suffix in suffix_list:
+                    shared_expert_weight_name = (
+                        f"model.layers.0.mlp.shared_experts.{suffix}"
+                    )
+                    weights_list.append(
+                        (
+                            f"model.layers.0."
+                            f"mlp.experts."
+                            f"{self.config.n_routed_experts + num_repeat}"
+                            f".{suffix}",
+                            weights_dict[shared_expert_weight_name],
+                        )
+                    )
+                    names_to_remove += [shared_expert_weight_name]
+            weights = [w for w in weights_list if w[0] not in names_to_remove]

         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
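The block added above re-keys each shared-expert tensor so it loads into the expert ids that follow the routed experts. A self-contained sketch of the same renaming pattern (hypothetical helper; strings stand in for real checkpoint tensors):

```python
from typing import Dict, List, Tuple


def fuse_shared_expert_weights(
    weights: List[Tuple[str, object]],
    n_routed_experts: int,
    n_share_experts_fusion: int,
    suffix_list: List[str],
) -> List[Tuple[str, object]]:
    # Clone model.layers.0.mlp.shared_experts.<suffix> under the expert ids
    # n_routed_experts, n_routed_experts + 1, ..., and drop the originals so
    # they are not loaded a second time.
    weights_dict: Dict[str, object] = dict(weights)
    cloned: List[Tuple[str, object]] = []
    names_to_remove = set()
    for k in range(n_share_experts_fusion):
        for suffix in suffix_list:
            src = f"model.layers.0.mlp.shared_experts.{suffix}"
            dst = f"model.layers.0.mlp.experts.{n_routed_experts + k}.{suffix}"
            cloned.append((dst, weights_dict[src]))
            names_to_remove.add(src)
    return [w for w in list(weights) + cloned if w[0] not in names_to_remove]


# Toy example: 2 routed experts, 1 fused replica of the shared expert.
toy = [
    ("model.layers.0.mlp.shared_experts.down_proj.weight", "W_shared"),
    ("model.layers.0.mlp.experts.0.down_proj.weight", "W_e0"),
]
print(fuse_shared_expert_weights(toy, 2, 1, ["down_proj.weight"]))
# [('model.layers.0.mlp.experts.0.down_proj.weight', 'W_e0'),
#  ('model.layers.0.mlp.experts.2.down_proj.weight', 'W_shared')]
```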
@@ -190,7 +239,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts,
+            num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
         )

         nextn_layer_prefix = "model.layers.0"
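Because the cloned replicas occupy extra expert slots, the expert-parameter mapping for the NextN layer must cover `n_routed_experts + n_share_experts_fusion` ids, which is what the one-line change above does. Toy arithmetic (the replica count of 8 is an example value, not something fixed by the commit):

```python
n_routed_experts = 256          # DeepSeek V3/R1 routed experts
n_share_experts_fusion = 8      # example value set via --n-share-experts-fusion
num_experts = n_routed_experts + n_share_experts_fusion
print(num_experts)              # 264 ids: 0-255 routed, 256-263 shared-expert copies
```

The remaining hunks below refactor `DeepseekV2ForCausalLM` so this gating logic lives in a reusable `determine_n_share_experts_fusion` method instead of being inlined in `__init__`.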
@@ -1440,11 +1440,27 @@ class DeepseekV2ForCausalLM(nn.Module):
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
+        self.determine_n_share_experts_fusion()
+        self.model = DeepseekV2Model(
+            config, quant_config, prefix=add_prefix("model", prefix)
+        )
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("lm_head", prefix),
+        )
+        self.logits_processor = LogitsProcessor(config)
+        self.dp_size = get_attention_dp_size()
+
+    def determine_n_share_experts_fusion(
+        self, architecture: str = "DeepseekV3ForCausalLM"
+    ):
         self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
         if self.n_share_experts_fusion > 0:
             # Only Deepseek V3/R1 can use shared experts fusion optimization now.
             if (
-                self.config.architectures[0] != "DeepseekV3ForCausalLM"
+                self.config.architectures[0] != architecture
                 or self.config.n_routed_experts != 256
             ):
                 self.n_share_experts_fusion = 0
@@ -1459,7 +1475,7 @@ class DeepseekV2ForCausalLM(nn.Module):
         elif self.n_share_experts_fusion == 0:
             if (
                 torch.cuda.get_device_capability("cuda") >= (9, 0)
-                and self.config.architectures[0] == "DeepseekV3ForCausalLM"
+                and self.config.architectures[0] == architecture
                 and self.config.n_routed_experts == 256
                 and (not global_server_args_dict["enable_deepep_moe"])
             ):
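The auto-enable branch in this hunk only changes the architecture string it compares against. A hypothetical pure-function version of that gate, for reference (names are illustrative):

```python
def should_auto_enable_fusion(
    device_capability: tuple,
    model_architecture: str,
    expected_architecture: str,
    n_routed_experts: int,
    enable_deepep_moe: bool,
) -> bool:
    # Mirrors the condition above: SM >= 90, matching architecture,
    # a V3/R1-sized expert pool, and DeepEP MoE disabled.
    return (
        device_capability >= (9, 0)
        and model_architecture == expected_architecture
        and n_routed_experts == 256
        and not enable_deepep_moe
    )


print(
    should_auto_enable_fusion(
        (9, 0), "DeepseekV3ForCausalLMNextN", "DeepseekV3ForCausalLMNextN", 256, False
    )
)  # True on Hopper-class (SM90) GPUs
```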
@@ -1469,18 +1485,6 @@ class DeepseekV2ForCausalLM(nn.Module):
                     "Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled."
                 )
-
-        self.model = DeepseekV2Model(
-            config, quant_config, prefix=add_prefix("model", prefix)
-        )
-        self.lm_head = ParallelLMHead(
-            config.vocab_size,
-            config.hidden_size,
-            quant_config=quant_config,
-            prefix=add_prefix("lm_head", prefix),
-        )
-        self.logits_processor = LogitsProcessor(config)
-        self.dp_size = get_attention_dp_size()

     def get_input_embeddings(self) -> nn.Embedding:
         return self.model.embed_tokens