Commit 9135afe4 authored by 王敏's avatar 王敏
Browse files

优化epsp代码

parent 76695c0a
......@@ -1007,6 +1007,8 @@ class DeepseekV2DecoderLayer(nn.Module):
prefix=f"{prefix}.mlp",
)
self.enable_ep_sp = isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1
self.is_mtp_layer = False
if self.layer_idx == config.num_hidden_layers:
self.is_mtp_layer = True
......@@ -1169,9 +1171,7 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
if not self.is_mtp_layer:
if isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1 and \
if not self.is_mtp_layer and self.enable_ep_sp and \
self.layer_idx > self.config.first_k_dense_replace:
hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0)
......@@ -1180,9 +1180,7 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states=hidden_states,
)
if not self.is_mtp_layer:
if isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1:
if not self.is_mtp_layer and self.enable_ep_sp:
if self.layer_idx == self.config.first_k_dense_replace:
residual = residual.tensor_split(self.tp_size)[self.tp_rank]
......@@ -1213,24 +1211,20 @@ class DeepseekV2DecoderLayer(nn.Module):
residual = hidden_states[self.dp_rank*new_bs: (self.dp_rank+1)*new_bs, :]
hidden_states = self.post_attention_layernorm(hidden_states)
if self.is_mtp_layer:
if isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1:
if self.is_mtp_layer and self.enable_ep_sp:
ori_bs = hidden_states.shape[0]
pad_size = (ori_bs + self.tp_size - 1) // self.tp_size * self.tp_size - ori_bs
if pad_size > 0:
hidden_states = torch.nn.functional.pad(hidden_states.contiguous(), [0, 0, 0, pad_size], value=0).contiguous()
hidden_states = torch.nn.functional.pad(hidden_states, [0, 0, 0, pad_size], value=0)
new_bs = (ori_bs+pad_size) // self.tp_size
hidden_states = hidden_states[self.tp_rank*new_bs: (self.tp_rank+1)*new_bs, :].contiguous()
hidden_states = hidden_states[self.tp_rank*new_bs: (self.tp_rank+1)*new_bs, :]
hidden_states = self.mlp(hidden_states)
if self.enable_dp_attention:
hidden_states = dp_reduce_scatter_tensor(hidden_states)
if self.is_mtp_layer:
if isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1:
if self.is_mtp_layer and self.enable_ep_sp:
hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0)
hidden_states = hidden_states[:ori_bs, :]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment