Unverified Commit 8005e606 authored by 杰兮's avatar 杰兮 Committed by GitHub
Browse files

[Bugfix][Rocm] Fix shared expert weight loading failure in DeepSeek-MTP (#27563)


Signed-off-by: default avatarzhyajie <yajizhan@amd.com>
Co-authored-by: default avatarzhyajie <yajizhan@amd.com>
parent 68dfe28e
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable import typing
from collections.abc import Callable, Iterable
import torch import torch
import torch.nn as nn import torch.nn as nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm._aiter_ops import rocm_aiter_ops
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
...@@ -231,6 +233,9 @@ class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts): ...@@ -231,6 +233,9 @@ class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts):
return self.model.compute_logits(hidden_states, spec_step_idx) return self.model.compute_logits(hidden_states, spec_step_idx)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
rocm_aiter_moe_shared_expert_enabled = (
rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
)
stacked_params_mapping = [ stacked_params_mapping = [
("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
...@@ -238,11 +243,16 @@ class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts): ...@@ -238,11 +243,16 @@ class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts):
("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1), ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
] ]
expert_params_mapping = FusedMoE.make_expert_params_mapping( expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj", ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts, num_experts=self.config.n_routed_experts
+ (
self.config.n_shared_experts
if rocm_aiter_moe_shared_expert_enabled
else 0
),
) )
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
...@@ -253,6 +263,9 @@ class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts): ...@@ -253,6 +263,9 @@ class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts):
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
if spec_layer is None: if spec_layer is None:
continue continue
is_fusion_moe_shared_experts_layer = (
rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name)
)
name = self._rewrite_spec_layer_name(spec_layer, name) name = self._rewrite_spec_layer_name(spec_layer, name)
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below). # Skip non-stacked layers and experts (experts handled below).
...@@ -266,6 +279,8 @@ class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts): ...@@ -266,6 +279,8 @@ class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts):
# for mlp.experts[0].gate_gate_up_proj, which breaks load. # for mlp.experts[0].gate_gate_up_proj, which breaks load.
if ("mlp.experts." in name) and name not in params_dict: if ("mlp.experts." in name) and name not in params_dict:
continue continue
if is_fusion_moe_shared_experts_layer:
continue
name_mapped = name.replace(weight_name, param_name) name_mapped = name.replace(weight_name, param_name)
# QKV fusion is optional, fall back to normal # QKV fusion is optional, fall back to normal
...@@ -286,45 +301,105 @@ class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts): ...@@ -286,45 +301,105 @@ class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts):
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
break break
else: else:
for mapping in expert_params_mapping: # Special handling: when AITER fusion_shared_experts is enabled,
param_name, weight_name, expert_id, shard_id = mapping # checkpoints may provide a single widened shared_experts tensor
if weight_name not in name: # without explicit expert indices
continue # (e.g. ...mlp.shared_experts.gate_proj.weight).
name = name.replace(weight_name, param_name) # For models with multiple shared experts, split that tensor
# evenly into per-shared-expert slices and load them into
param = params_dict[name] # appended expert slots mlp.experts.{n_routed_experts + j}.*
weight_loader = param.weight_loader # accordingly.
weight_loader( num_chunks = 1
param, if is_fusion_moe_shared_experts_layer:
loaded_weight, num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
name, # Determine split axis based on op type
shard_id=shard_id, # gate/up: ColumnParallel → split along dim 0
expert_id=expert_id, # down: RowParallel → split along dim 1
) split_dim = 1 if "down_proj.weight" in name else 0
break total = loaded_weight.shape[split_dim]
else: assert total % num_chunks == 0, (
# Skip loading extra bias for GPTQ models. f"Shared expert weight dim {total} "
if name.endswith(".bias") and name not in params_dict: f"not divisible by num_chunks {num_chunks}"
continue
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
# According to DeepSeek-V3 Technical Report, MTP modules
# shares embedding layer. We only load the first weights.
if (
spec_layer != self.model.mtp_start_layer_idx
and ".layers" not in name
):
continue
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
) )
weight_loader(param, loaded_weight) chunk_size = total // num_chunks
loaded_params.add(name)
for j in range(num_chunks):
chunk_name = name
weight_to_load = loaded_weight
if is_fusion_moe_shared_experts_layer:
if split_dim == 0:
weight_to_load = loaded_weight[
j * chunk_size : (j + 1) * chunk_size, :
]
else:
weight_to_load = loaded_weight[
:, j * chunk_size : (j + 1) * chunk_size
]
# Synthesize an expert-style name so expert mapping
# can route it
chunk_name = name.replace(
"mlp.shared_experts",
f"mlp.experts.{self.config.n_routed_experts + j}",
)
# Use expert_params_mapping to locate the destination
# param and delegate to its expert-aware weight_loader
# with expert_id.
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in chunk_name:
continue
# Do not modify `name` since the loop may continue here
# Instead, create a new variable
name_mapped = chunk_name.replace(weight_name, param_name)
param = params_dict[name_mapped]
# We should ask the weight loader to return success or
# not here since otherwise we may skip experts with
# other available replicas.
weight_loader = typing.cast(
Callable[..., bool], param.weight_loader
)
success = weight_loader(
param,
weight_to_load,
name_mapped,
shard_id=shard_id,
expert_id=expert_id,
return_success=True,
)
if success:
if not is_fusion_moe_shared_experts_layer:
name = name_mapped
else:
loaded_params.add(name_mapped)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
# According to DeepSeek-V3 Technical Report, MTP modules
# shares embedding layer. We only load the first weights.
if (
spec_layer != self.model.mtp_start_layer_idx
and ".layers" not in name
):
continue
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
if not is_fusion_moe_shared_experts_layer:
loaded_params.add(name)
return loaded_params return loaded_params
def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str: def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
......
...@@ -1479,8 +1479,8 @@ class DeepseekV2ForCausalLM( ...@@ -1479,8 +1479,8 @@ class DeepseekV2ForCausalLM(
if spec_layer is not None: if spec_layer is not None:
continue # skip spec decode layers for main model continue # skip spec decode layers for main model
is_fuse_shared_experts_layer = rocm_aiter_moe_shared_expert_enabled and ( is_fusion_moe_shared_experts_layer = (
"mlp.shared_experts" in name rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name)
) )
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
...@@ -1495,7 +1495,7 @@ class DeepseekV2ForCausalLM( ...@@ -1495,7 +1495,7 @@ class DeepseekV2ForCausalLM(
# for mlp.experts[0].gate_gate_up_proj, which breaks load. # for mlp.experts[0].gate_gate_up_proj, which breaks load.
if ("mlp.experts." in name) and name not in params_dict: if ("mlp.experts." in name) and name not in params_dict:
continue continue
if is_fuse_shared_experts_layer: if is_fusion_moe_shared_experts_layer:
continue continue
name_mapped = name.replace(weight_name, param_name) name_mapped = name.replace(weight_name, param_name)
...@@ -1531,7 +1531,7 @@ class DeepseekV2ForCausalLM( ...@@ -1531,7 +1531,7 @@ class DeepseekV2ForCausalLM(
# appended expert slots mlp.experts.{n_routed_experts + j}.* # appended expert slots mlp.experts.{n_routed_experts + j}.*
# accordingly. # accordingly.
num_chunks = 1 num_chunks = 1
if is_fuse_shared_experts_layer: if is_fusion_moe_shared_experts_layer:
num_chunks = getattr(self.config, "n_shared_experts", 1) or 1 num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
# Determine split axis based on op type # Determine split axis based on op type
# gate/up: ColumnParallel → split along dim 0 # gate/up: ColumnParallel → split along dim 0
...@@ -1548,7 +1548,7 @@ class DeepseekV2ForCausalLM( ...@@ -1548,7 +1548,7 @@ class DeepseekV2ForCausalLM(
chunk_name = name chunk_name = name
weight_to_load = loaded_weight weight_to_load = loaded_weight
if is_fuse_shared_experts_layer: if is_fusion_moe_shared_experts_layer:
if split_dim == 0: if split_dim == 0:
weight_to_load = loaded_weight[ weight_to_load = loaded_weight[
j * chunk_size : (j + 1) * chunk_size, : j * chunk_size : (j + 1) * chunk_size, :
...@@ -1599,7 +1599,7 @@ class DeepseekV2ForCausalLM( ...@@ -1599,7 +1599,7 @@ class DeepseekV2ForCausalLM(
return_success=True, return_success=True,
) )
if success: if success:
if not is_fuse_shared_experts_layer: if not is_fusion_moe_shared_experts_layer:
name = name_mapped name = name_mapped
else: else:
loaded_params.add(name_mapped) loaded_params.add(name_mapped)
...@@ -1628,7 +1628,7 @@ class DeepseekV2ForCausalLM( ...@@ -1628,7 +1628,7 @@ class DeepseekV2ForCausalLM(
param, "weight_loader", default_weight_loader param, "weight_loader", default_weight_loader
) )
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
if not is_fuse_shared_experts_layer: if not is_fusion_moe_shared_experts_layer:
loaded_params.add(name) loaded_params.add(name)
return loaded_params return loaded_params
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment