Unverified Commit 8dc191f2 authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Fix one wasted kernel in DeepSeek and minor refactor (#6316)

parent 64825b83
...@@ -1336,28 +1336,16 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1336,28 +1336,16 @@ class DeepseekV2DecoderLayer(nn.Module):
) )
if self.attn_tp_size != 1: if self.attn_tp_size != 1:
if self.input_is_scattered: tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
tensor_list = list(hidden_states.tensor_split(self.attn_tp_size)) hidden_states = tensor_list[self.attn_tp_rank]
hidden_states = tensor_list[self.attn_tp_rank] attn_tp_reduce_scatter(hidden_states, tensor_list)
attn_tp_reduce_scatter(hidden_states, tensor_list) if not self.input_is_scattered:
if hidden_states.shape[0] != 0: residual = residual.tensor_split(self.attn_tp_size)[self.attn_tp_rank]
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual if hidden_states.shape[0] != 0:
) hidden_states, residual = self.post_attention_layernorm(
else: hidden_states, residual
if self.attn_tp_rank == 0: )
hidden_states += residual
tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
hidden_states = tensor_list[self.attn_tp_rank]
attn_tp_reduce_scatter(hidden_states, tensor_list)
residual = hidden_states
if hidden_states.shape[0] != 0:
hidden_states = self.post_attention_layernorm(hidden_states)
else:
if hidden_states.shape[0] != 0:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual
)
if not ( if not (
self._enable_moe_dense_fully_dp() self._enable_moe_dense_fully_dp()
...@@ -1859,7 +1847,6 @@ class DeepseekV2ForCausalLM(nn.Module): ...@@ -1859,7 +1847,6 @@ class DeepseekV2ForCausalLM(nn.Module):
q_a_proj_name in cached_a_proj q_a_proj_name in cached_a_proj
and kv_a_proj_name in cached_a_proj and kv_a_proj_name in cached_a_proj
): ):
q_a_proj_weight = cached_a_proj[q_a_proj_name] q_a_proj_weight = cached_a_proj[q_a_proj_name]
kv_a_proj_weight = cached_a_proj[kv_a_proj_name] kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
fused_weight = torch.cat( fused_weight = torch.cat(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment