"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "cd1b8d7ca8de01aa58476eb311a45091b423358e"
Unverified Commit 5d93a950 authored by Yuhong Guo, committed by GitHub

[BugFix] Fix combination of MTP and `--n-share-experts-fusion` with R1 (#5707)

parent c998d04b
@@ -13,12 +13,14 @@
 # ==============================================================================
 """Inference-only DeepSeek NextN Speculative Decoding."""
+import logging
 from typing import Iterable, Optional, Tuple

 import torch
 from torch import nn
 from transformers import PretrainedConfig

+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import ReplicatedLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor

@@ -51,6 +53,9 @@ else:
     from vllm._custom_ops import awq_dequantize

+logger = logging.getLogger(__name__)
+

 class DeepseekModelNextN(nn.Module):
     def __init__(
         self,

@@ -134,7 +139,9 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
     ) -> None:
         nn.Module.__init__(self)
         self.config = config
+        self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
+        self.determine_n_share_experts_fusion("DeepseekV3ForCausalLMNextN")
         self.model = DeepseekModelNextN(
             config, quant_config, prefix=add_prefix("model", prefix)
         )

@@ -182,6 +189,48 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
+        if self.n_share_experts_fusion > 0:
+            logger.info(
+                f"Cloning {self.n_share_experts_fusion} "
+                "replicas of the shared expert into MoE for DeepseekV3ForCausalLMNextN"
+            )
+            weights_list = list(weights)
+            weights_dict = dict(weights_list)
+            if self.quant_config is None or self.quant_config.get_name() == "w8a8_int8":
+                suffix_list = [
+                    "down_proj.weight",
+                    "down_proj.weight_scale",
+                    "gate_proj.weight",
+                    "gate_proj.weight_scale",
+                    "up_proj.weight",
+                    "up_proj.weight_scale",
+                ]
+            else:
+                suffix_list = [
+                    "down_proj.weight",
+                    "down_proj.weight_scale_inv",
+                    "gate_proj.weight",
+                    "gate_proj.weight_scale_inv",
+                    "up_proj.weight",
+                    "up_proj.weight_scale_inv",
+                ]
+            names_to_remove = []
+            for num_repeat in range(self.n_share_experts_fusion):
+                for suffix in suffix_list:
+                    shared_expert_weight_name = (
+                        f"model.layers.0.mlp.shared_experts.{suffix}"
+                    )
+                    weights_list.append(
+                        (
+                            f"model.layers.0."
+                            f"mlp.experts."
+                            f"{self.config.n_routed_experts + num_repeat}"
+                            f".{suffix}",
+                            weights_dict[shared_expert_weight_name],
+                        )
+                    )
+                    names_to_remove += [shared_expert_weight_name]
+            weights = [w for w in weights_list if w[0] not in names_to_remove]

         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)

@@ -190,7 +239,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts,
+            num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
         )

         nextn_layer_prefix = "model.layers.0"
...
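For readers unfamiliar with the shared-experts fusion trick, here is a minimal standalone sketch of what the added loader block in the NextN (MTP draft) model does; the tensor shapes, expert count, and replica count below are toy values for illustration and are not part of this PR. Each layer-0 shared-expert tensor is re-registered under the name of an extra routed-expert slot (index n_routed_experts + i), and the original shared-expert entries are dropped so they are not loaded twice, which is also why the last hunk bumps num_experts to n_routed_experts + n_share_experts_fusion.

# Standalone illustration of the shared-expert fusion remapping (toy values,
# not SGLang code): shared-expert weights of layer 0 are cloned into extra
# routed-expert slots appended after the real ones.
import torch

n_routed_experts = 4        # toy value; DeepSeek V3/R1 has 256
n_share_experts_fusion = 2  # toy value; in practice tied to the TP world size
suffix_list = ["down_proj.weight", "gate_proj.weight", "up_proj.weight"]

# Toy checkpoint entries for layer 0: routed experts plus one shared expert.
weights = [
    (f"model.layers.0.mlp.experts.{e}.{s}", torch.randn(2, 2))
    for e in range(n_routed_experts)
    for s in suffix_list
] + [(f"model.layers.0.mlp.shared_experts.{s}", torch.randn(2, 2)) for s in suffix_list]

weights_dict = dict(weights)
names_to_remove = []
for num_repeat in range(n_share_experts_fusion):
    for suffix in suffix_list:
        shared_name = f"model.layers.0.mlp.shared_experts.{suffix}"
        # Clone the shared-expert tensor into routed-expert slot 4, 5, ...
        weights.append(
            (f"model.layers.0.mlp.experts.{n_routed_experts + num_repeat}.{suffix}",
             weights_dict[shared_name])
        )
        names_to_remove.append(shared_name)
weights = [w for w in weights if w[0] not in names_to_remove]

# The expert mapping must now cover n_routed_experts + n_share_experts_fusion slots.
expert_ids = {int(n.split(".experts.")[1].split(".")[0]) for n, _ in weights if ".experts." in n}
print(sorted(expert_ids))  # [0, 1, 2, 3, 4, 5]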
@@ -1440,11 +1440,27 @@ class DeepseekV2ForCausalLM(nn.Module):
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
+        self.determine_n_share_experts_fusion()
+        self.model = DeepseekV2Model(
+            config, quant_config, prefix=add_prefix("model", prefix)
+        )
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("lm_head", prefix),
+        )
+        self.logits_processor = LogitsProcessor(config)
+        self.dp_size = get_attention_dp_size()
+
+    def determine_n_share_experts_fusion(
+        self, architecture: str = "DeepseekV3ForCausalLM"
+    ):
         self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
         if self.n_share_experts_fusion > 0:
             # Only Deepseek V3/R1 can use shared experts fusion optimization now.
             if (
-                self.config.architectures[0] != "DeepseekV3ForCausalLM"
+                self.config.architectures[0] != architecture
                 or self.config.n_routed_experts != 256
             ):
                 self.n_share_experts_fusion = 0

@@ -1459,7 +1475,7 @@ class DeepseekV2ForCausalLM(nn.Module):
         elif self.n_share_experts_fusion == 0:
             if (
                 torch.cuda.get_device_capability("cuda") >= (9, 0)
-                and self.config.architectures[0] == "DeepseekV3ForCausalLM"
+                and self.config.architectures[0] == architecture
                 and self.config.n_routed_experts == 256
                 and (not global_server_args_dict["enable_deepep_moe"])
             ):

@@ -1469,18 +1485,6 @@ class DeepseekV2ForCausalLM(nn.Module):
                 "Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled."
             )

-        self.model = DeepseekV2Model(
-            config, quant_config, prefix=add_prefix("model", prefix)
-        )
-        self.lm_head = ParallelLMHead(
-            config.vocab_size,
-            config.hidden_size,
-            quant_config=quant_config,
-            prefix=add_prefix("lm_head", prefix),
-        )
-        self.logits_processor = LogitsProcessor(config)
-        self.dp_size = get_attention_dp_size()

     def get_input_embeddings(self) -> nn.Embedding:
         return self.model.embed_tokens
...
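The constructor refactor in DeepseekV2ForCausalLM above exists so the NextN draft model can reuse the same gate: determine_n_share_experts_fusion now takes the expected architecture name as a parameter, and the MTP head passes "DeepseekV3ForCausalLMNextN", so its fusion count is no longer reset to zero by the hardcoded "DeepseekV3ForCausalLM" check. Below is a minimal standalone sketch of that gating rule; the ToyConfig class and resolve_fusion helper are made up for illustration and are not the SGLang implementation.

# Standalone sketch of the gating rule in determine_n_share_experts_fusion
# (toy config object, not the SGLang class): fusion stays enabled only when
# the config's architecture matches the name the calling class passes in.
from dataclasses import dataclass, field
from typing import List

@dataclass
class ToyConfig:
    architectures: List[str] = field(
        default_factory=lambda: ["DeepseekV3ForCausalLMNextN"]
    )
    n_routed_experts: int = 256

def resolve_fusion(requested: int, config: ToyConfig, architecture: str) -> int:
    """Return the effective number of fused shared-expert replicas."""
    if requested > 0 and (
        config.architectures[0] != architecture or config.n_routed_experts != 256
    ):
        return 0  # only DeepSeek V3/R1-shaped models support the optimization
    return requested

cfg = ToyConfig()
# With the previously hardcoded target, a NextN-architected config is rejected ...
print(resolve_fusion(8, cfg, "DeepseekV3ForCausalLM"))       # -> 0
# ... while passing the NextN architecture keeps the requested replica count.
print(resolve_fusion(8, cfg, "DeepseekV3ForCausalLMNextN"))  # -> 8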