Commit 9135afe4 authored by 王敏
Browse files

优化epsp代码

parent 76695c0a
...@@ -1007,6 +1007,8 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1007,6 +1007,8 @@ class DeepseekV2DecoderLayer(nn.Module):
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
) )
self.enable_ep_sp = isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1
self.is_mtp_layer = False self.is_mtp_layer = False
if self.layer_idx == config.num_hidden_layers: if self.layer_idx == config.num_hidden_layers:
self.is_mtp_layer = True self.is_mtp_layer = True
...@@ -1169,24 +1171,20 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1169,24 +1171,20 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states, residual = self.input_layernorm( hidden_states, residual = self.input_layernorm(
hidden_states, residual) hidden_states, residual)
if not self.is_mtp_layer: if not self.is_mtp_layer and self.enable_ep_sp and \
if isinstance(self.mlp, self.layer_idx > self.config.first_k_dense_replace:
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1 and \ hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0)
self.layer_idx > self.config.first_k_dense_replace:
hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0)
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
) )
if not self.is_mtp_layer: if not self.is_mtp_layer and self.enable_ep_sp:
if isinstance(self.mlp, if self.layer_idx == self.config.first_k_dense_replace:
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1: residual = residual.tensor_split(self.tp_size)[self.tp_rank]
if self.layer_idx == self.config.first_k_dense_replace:
residual = residual.tensor_split(self.tp_size)[self.tp_rank]
hidden_states = tensor_model_parallel_reduce_scatter(hidden_states, dim=0) hidden_states = tensor_model_parallel_reduce_scatter(hidden_states, dim=0)
if self.enable_dp_attention: if self.enable_dp_attention:
if self.tp_rank == 0: if self.tp_rank == 0:
...@@ -1213,26 +1211,22 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1213,26 +1211,22 @@ class DeepseekV2DecoderLayer(nn.Module):
residual = hidden_states[self.dp_rank*new_bs: (self.dp_rank+1)*new_bs, :] residual = hidden_states[self.dp_rank*new_bs: (self.dp_rank+1)*new_bs, :]
hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.post_attention_layernorm(hidden_states)
if self.is_mtp_layer: if self.is_mtp_layer and self.enable_ep_sp:
if isinstance(self.mlp, ori_bs = hidden_states.shape[0]
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1: pad_size = (ori_bs + self.tp_size - 1) // self.tp_size * self.tp_size - ori_bs
ori_bs = hidden_states.shape[0] if pad_size > 0:
pad_size = (ori_bs + self.tp_size - 1) // self.tp_size * self.tp_size - ori_bs hidden_states = torch.nn.functional.pad(hidden_states, [0, 0, 0, pad_size], value=0)
if pad_size > 0: new_bs = (ori_bs+pad_size) // self.tp_size
hidden_states = torch.nn.functional.pad(hidden_states.contiguous(), [0, 0, 0, pad_size], value=0).contiguous() hidden_states = hidden_states[self.tp_rank*new_bs: (self.tp_rank+1)*new_bs, :]
new_bs = (ori_bs+pad_size) // self.tp_size
hidden_states = hidden_states[self.tp_rank*new_bs: (self.tp_rank+1)*new_bs, :].contiguous()
hidden_states = self.mlp(hidden_states) hidden_states = self.mlp(hidden_states)
if self.enable_dp_attention: if self.enable_dp_attention:
hidden_states = dp_reduce_scatter_tensor(hidden_states) hidden_states = dp_reduce_scatter_tensor(hidden_states)
if self.is_mtp_layer: if self.is_mtp_layer and self.enable_ep_sp:
if isinstance(self.mlp, hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0)
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1: hidden_states = hidden_states[:ori_bs, :]
hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0)
hidden_states = hidden_states[:ori_bs, :]
if isinstance(self.mlp, if isinstance(self.mlp,
DeepseekV2MLP) and hidden_states.dtype == torch.float16: DeepseekV2MLP) and hidden_states.dtype == torch.float16:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment