"vscode:/vscode.git/clone" did not exist on "ca3ea51bde6c22d0afb3aa0a3fdba6d568095a0a"
Commit 9135afe4 authored by 王敏's avatar 王敏
Browse files

优化epsp代码

parent 76695c0a
...@@ -1007,6 +1007,8 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1007,6 +1007,8 @@ class DeepseekV2DecoderLayer(nn.Module):
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
) )
self.enable_ep_sp = isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1
self.is_mtp_layer = False self.is_mtp_layer = False
if self.layer_idx == config.num_hidden_layers: if self.layer_idx == config.num_hidden_layers:
self.is_mtp_layer = True self.is_mtp_layer = True
...@@ -1169,9 +1171,7 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1169,9 +1171,7 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states, residual = self.input_layernorm( hidden_states, residual = self.input_layernorm(
hidden_states, residual) hidden_states, residual)
if not self.is_mtp_layer: if not self.is_mtp_layer and self.enable_ep_sp and \
if isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1 and \
self.layer_idx > self.config.first_k_dense_replace: self.layer_idx > self.config.first_k_dense_replace:
hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0) hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0)
...@@ -1180,9 +1180,7 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1180,9 +1180,7 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states=hidden_states, hidden_states=hidden_states,
) )
if not self.is_mtp_layer: if not self.is_mtp_layer and self.enable_ep_sp:
if isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1:
if self.layer_idx == self.config.first_k_dense_replace: if self.layer_idx == self.config.first_k_dense_replace:
residual = residual.tensor_split(self.tp_size)[self.tp_rank] residual = residual.tensor_split(self.tp_size)[self.tp_rank]
...@@ -1213,24 +1211,20 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1213,24 +1211,20 @@ class DeepseekV2DecoderLayer(nn.Module):
residual = hidden_states[self.dp_rank*new_bs: (self.dp_rank+1)*new_bs, :] residual = hidden_states[self.dp_rank*new_bs: (self.dp_rank+1)*new_bs, :]
hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.post_attention_layernorm(hidden_states)
if self.is_mtp_layer: if self.is_mtp_layer and self.enable_ep_sp:
if isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1:
ori_bs = hidden_states.shape[0] ori_bs = hidden_states.shape[0]
pad_size = (ori_bs + self.tp_size - 1) // self.tp_size * self.tp_size - ori_bs pad_size = (ori_bs + self.tp_size - 1) // self.tp_size * self.tp_size - ori_bs
if pad_size > 0: if pad_size > 0:
hidden_states = torch.nn.functional.pad(hidden_states.contiguous(), [0, 0, 0, pad_size], value=0).contiguous() hidden_states = torch.nn.functional.pad(hidden_states, [0, 0, 0, pad_size], value=0)
new_bs = (ori_bs+pad_size) // self.tp_size new_bs = (ori_bs+pad_size) // self.tp_size
hidden_states = hidden_states[self.tp_rank*new_bs: (self.tp_rank+1)*new_bs, :].contiguous() hidden_states = hidden_states[self.tp_rank*new_bs: (self.tp_rank+1)*new_bs, :]
hidden_states = self.mlp(hidden_states) hidden_states = self.mlp(hidden_states)
if self.enable_dp_attention: if self.enable_dp_attention:
hidden_states = dp_reduce_scatter_tensor(hidden_states) hidden_states = dp_reduce_scatter_tensor(hidden_states)
if self.is_mtp_layer: if self.is_mtp_layer and self.enable_ep_sp:
if isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1:
hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0) hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0)
hidden_states = hidden_states[:ori_bs, :] hidden_states = hidden_states[:ori_bs, :]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment