Commit 6c6c9c0d authored by 王敏's avatar 王敏
Browse files

[Fix]修复PP模式pcp报错

parent 26645e58
...@@ -1409,6 +1409,9 @@ class DeepseekV2Model(nn.Module): ...@@ -1409,6 +1409,9 @@ class DeepseekV2Model(nn.Module):
) )
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
if enable_mla_cp:
hidden_states = tensor_model_parallel_all_gather(hidden_states.contiguous(), dim=0)
residual = tensor_model_parallel_all_gather(residual.contiguous(), dim=0)
return IntermediateTensors( return IntermediateTensors(
{"hidden_states": hidden_states, "residual": residual} {"hidden_states": hidden_states, "residual": residual}
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment