Commit a3776adc authored by 王敏
Browse files

解决PCP启动报错问题

parent b58514dd
...@@ -268,7 +268,7 @@ class DeepseekV2MLP(nn.Module): ...@@ -268,7 +268,7 @@ class DeepseekV2MLP(nn.Module):
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None *, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
): ):
enable_lightly_cp = get_forward_context().enable_lightly_cp enable_lightly_cp = get_forward_context().enable_lightly_cp
if enable_lightly_cp: if enable_lightly_cp:
if iqis is not None and iqis[0] is not None and iqis[1] is not None: if iqis is not None and iqis[0] is not None and iqis[1] is not None:
iqis = iqis_all_gather(iqis, tp_size=self.tp_size) iqis = iqis_all_gather(iqis, tp_size=self.tp_size)
else: else:
...@@ -294,6 +294,66 @@ class DeepseekV2MLP(nn.Module): ...@@ -294,6 +294,66 @@ class DeepseekV2MLP(nn.Module):
elif self.tp_size > 1: elif self.tp_size > 1:
x = tensor_model_parallel_all_reduce(x) x = tensor_model_parallel_all_reduce(x)
return x return x
class DeepseekV2SharedMLP(nn.Module):
    """Gated SiLU feed-forward block (gate/up projection -> SiLU-mul -> down).

    The module owns a merged column-parallel ``gate_up_proj`` producing two
    outputs of ``intermediate_size`` each, a row-parallel ``down_proj`` back
    to ``hidden_size``, and a ``SiluAndMul`` activation. ``forward`` issues no
    gather/all-reduce call of its own; ``reduce_results`` is simply handed to
    ``down_proj``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool = True,
        is_sequence_parallel=False,
        prefix: str = "",
    ) -> None:
        """Build the two projections and the activation.

        Args:
            hidden_size: Input width of ``gate_up_proj`` and output width of
                ``down_proj``.
            intermediate_size: Width of each of the two merged gate/up outputs
                and input width of ``down_proj``.
            hidden_act: Activation name; must be ``"silu"``.
            quant_config: Forwarded unchanged to both linear layers.
            reduce_results: Forwarded to ``down_proj``.
            is_sequence_parallel: Forwarded to both layers as ``disable_tp``.
            prefix: Name prefix; ``.gate_up_proj`` / ``.down_proj`` are
                appended for the respective layers.

        Raises:
            ValueError: If ``hidden_act`` is anything other than ``"silu"``.
        """
        super().__init__()

        # Sequence-parallel mode: activations are already sharded across the
        # ranks of the tp_group, so the weights are replicated and no
        # collective ops are required. In the regular mode this is standard
        # tensor parallelism with an all-reduce at the end.
        gate_up_widths = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            gate_up_widths,
            bias=False,
            quant_config=quant_config,
            disable_tp=is_sequence_parallel,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            disable_tp=is_sequence_parallel,
            prefix=f"{prefix}.down_proj",
        )

        # The check deliberately stays after layer construction, as before.
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(
        self,
        x,
        *, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
    ):
        """Apply the MLP to ``x`` and return the ``down_proj`` output.

        Args:
            x: Input passed to ``gate_up_proj``.
            iqis: Optional tensor pair handed to ``gate_up_proj`` only when
                ``envs.USE_FUSED_RMS_QUANT`` is set; ignored otherwise.
                NOTE(review): presumably pre-quantized input plus scales --
                confirm against the caller.

        Returns:
            The first element of the tuple returned by ``down_proj``.
        """
        # Plain path: no fused quantization involved, ``iqis`` is not used.
        if not envs.USE_FUSED_RMS_QUANT:
            gate_up, _ = self.gate_up_proj(x)
            out, _ = self.down_proj(self.act_fn(gate_up))
            return out

        gate_up, _ = self.gate_up_proj(x, iqis=iqis)

        # Fused RMS-quant input, but regular SiLU-mul activation.
        if not envs.USE_FUSED_SILU_MUL_QUANT:
            out, _ = self.down_proj(self.act_fn(gate_up))
            return out

        # Imported lazily so the dependency is only needed on this path.
        from lmslim.quantize.quant_ops import lm_fuse_silu_mul_quant

        xq, xs = lm_fuse_silu_mul_quant(gate_up)
        # down_proj still receives gate_up positionally; the fused kernel's
        # output pair travels through ``iqis``.
        # NOTE(review): presumably down_proj consumes ``iqis`` instead of the
        # raw input when it is provided -- confirm.
        out, _ = self.down_proj(gate_up, iqis=(xq, xs))
        return out
class DeepseekV2MoE(nn.Module): class DeepseekV2MoE(nn.Module):
...@@ -366,7 +426,7 @@ class DeepseekV2MoE(nn.Module): ...@@ -366,7 +426,7 @@ class DeepseekV2MoE(nn.Module):
else: else:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts intermediate_size = config.moe_intermediate_size * config.n_shared_experts
self.shared_experts = DeepseekV2MLP( self.shared_experts = DeepseekV2SharedMLP(
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=intermediate_size, intermediate_size=intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment