Unverified Commit a303325f authored by fzyzcjy, committed by GitHub

Fix DeepSeek bug causing 2.2% MMLU drop when TP!=DP (#4883)


Co-authored-by: ch-wan <cwan39@gatech.edu>
parent 42873eac
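
What changed, as read from the diff below (the commit message itself only names the symptom): when attn_tp_size != 1, the decoder layer used to all-gather hidden_states and residual separately, and both gathers staged through the same forward_batch.gathered_buffer slice, leaving the two tensors aliased to one storage. The fix folds the residual into hidden_states before the single remaining gather (hidden_states += residual; residual = None), teaches the final norm in DeepseekV2Model to handle residual is None, and guards the no-longer-exercised scattered-input combination with an assert.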
@@ -1102,6 +1102,10 @@ class DeepseekV2DecoderLayer(nn.Module):
         else:
             hidden_states, residual = self.input_layernorm(hidden_states, residual)

+        assert not (
+            self.attn_tp_size != 1 and self.input_is_scattered
+        ), "moe_layer_freq > 1 is not supported when attn_tp_size > 1"
+
         # Self Attention
         hidden_states = self.self_attn(
             positions=positions,
@@ -1109,22 +1113,6 @@ class DeepseekV2DecoderLayer(nn.Module):
             forward_batch=forward_batch,
         )

-        if self.attn_tp_size != 1 and self.input_is_scattered:
-            hidden_states, local_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            tp_all_gather(
-                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
-            )
-            residual, local_residual = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                residual,
-            )
-            tp_all_gather(
-                list(residual.tensor_split(self.attn_tp_size)), local_residual
-            )
-
         # Gather
         if get_tensor_model_parallel_world_size() > 1:
             # all gather and all reduce
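
The block removed above staged both all-gathers through the same forward_batch.gathered_buffer slice, so hidden_states and residual became views of one storage and the second gather overwrote the first. A minimal single-process sketch of that aliasing hazard (fake_all_gather is a hypothetical stand-in for sglang's tp_all_gather; shapes are made up):

```python
import torch

num_tokens, hidden, attn_tp_size = 4, 8, 2
gathered_buffer = torch.empty(num_tokens, hidden)

def fake_all_gather(out_chunks, local):
    # Hypothetical stand-in for tp_all_gather: pretend every rank
    # contributed the same local shard.
    for chunk in out_chunks:
        chunk.copy_(local)

local_hidden = torch.ones(num_tokens // attn_tp_size, hidden)  # this rank's attention output
local_residual = torch.full((num_tokens // attn_tp_size, hidden), 5.0)

hidden_states = gathered_buffer[:num_tokens]
fake_all_gather(list(hidden_states.tensor_split(attn_tp_size)), local_hidden)

residual = gathered_buffer[:num_tokens]  # the SAME storage as hidden_states
fake_all_gather(list(residual.tensor_split(attn_tp_size)), local_residual)

print(hidden_states[0, 0].item())  # 5.0 -- the gathered attention output is gone
```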
@@ -1223,6 +1211,8 @@ class DeepseekV2DecoderLayer(nn.Module):
             hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)

         if self.is_last_layer and self.attn_tp_size != 1:
+            hidden_states += residual
+            residual = None
             hidden_states, local_hidden_states = (
                 forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
                 hidden_states,
@@ -1230,19 +1220,11 @@ class DeepseekV2DecoderLayer(nn.Module):
             tp_all_gather(
                 list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
             )
-            residual, local_residual = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                residual,
-            )
-            tp_all_gather(
-                list(residual.tensor_split(self.attn_tp_size)), local_residual
-            )

         return hidden_states, residual


 class DeepseekV2Model(nn.Module):
     fall_back_to_pt_during_load = False

     def __init__(
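
The last-layer path now merges the residual before the one remaining gather, so only a single tensor ever needs the shared buffer. A sketch of that pattern, under the same stand-in assumptions as the sketch above:

```python
import torch

num_tokens, hidden, attn_tp_size = 4, 8, 2
gathered_buffer = torch.empty(num_tokens, hidden)

def fake_all_gather(out_chunks, local):
    for chunk in out_chunks:  # hypothetical stand-in for tp_all_gather
        chunk.copy_(local)

hidden_states = torch.ones(num_tokens // attn_tp_size, hidden)
residual = torch.full((num_tokens // attn_tp_size, hidden), 5.0)

hidden_states = hidden_states + residual  # fold the residual in first ...
residual = None                           # ... and mark it as already applied

hidden_states, local_hidden_states = gathered_buffer[:num_tokens], hidden_states
fake_all_gather(list(hidden_states.tensor_split(attn_tp_size)), local_hidden_states)

print(hidden_states[0, 0].item())  # 6.0 -- one gather, nothing clobbered
```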
@@ -1296,6 +1278,9 @@ class DeepseekV2Model(nn.Module):
                 positions, hidden_states, forward_batch, residual
             )
         if not forward_batch.forward_mode.is_idle():
-            hidden_states, _ = self.norm(hidden_states, residual)
+            if residual is None:
+                hidden_states = self.norm(hidden_states)
+            else:
+                hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
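
Downstream, DeepseekV2Model.forward has to honor the residual = None convention, which is what the final hunk does. A sketch of that contract, using torch.nn.RMSNorm (torch >= 2.4) as a plain stand-in for the model's fused add-and-norm:

```python
import torch

norm = torch.nn.RMSNorm(8)  # stand-in; sglang's norm fuses the residual add

def final_norm(hidden_states, residual):
    if residual is None:
        # The last decoder layer already folded the residual in.
        return norm(hidden_states)
    # Emulate the fused path: add the residual, then normalize.
    return norm(hidden_states + residual)

print(final_norm(torch.ones(2, 8), None).shape)                  # torch.Size([2, 8])
print(final_norm(torch.ones(2, 8), torch.ones(2, 8)).shape)      # torch.Size([2, 8])
```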