"examples/community/pipeline_stable_diffusion_boxdiff.py" did not exist on "aa82df52e719f22a51f2881ebe15d2904586160a"
Unverified Commit a303325f authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Fix DeepSeek bug causing 2.2% MMLU drop when TP!=DP (#4883)


Co-authored-by: default avatarch-wan <cwan39@gatech.edu>
parent 42873eac
......@@ -1102,6 +1102,10 @@ class DeepseekV2DecoderLayer(nn.Module):
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
assert not (
self.attn_tp_size != 1 and self.input_is_scattered
), "moe_layer_freq > 1 is not supported when attn_tp_size > 1"
# Self Attention
hidden_states = self.self_attn(
positions=positions,
......@@ -1109,22 +1113,6 @@ class DeepseekV2DecoderLayer(nn.Module):
forward_batch=forward_batch,
)
if self.attn_tp_size != 1 and self.input_is_scattered:
hidden_states, local_hidden_states = (
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
hidden_states,
)
tp_all_gather(
list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
)
residual, local_residual = (
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
residual,
)
tp_all_gather(
list(residual.tensor_split(self.attn_tp_size)), local_residual
)
# Gather
if get_tensor_model_parallel_world_size() > 1:
# all gather and all reduce
......@@ -1223,6 +1211,8 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
if self.is_last_layer and self.attn_tp_size != 1:
hidden_states += residual
residual = None
hidden_states, local_hidden_states = (
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
hidden_states,
......@@ -1230,19 +1220,11 @@ class DeepseekV2DecoderLayer(nn.Module):
tp_all_gather(
list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
)
residual, local_residual = (
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
residual,
)
tp_all_gather(
list(residual.tensor_split(self.attn_tp_size)), local_residual
)
return hidden_states, residual
class DeepseekV2Model(nn.Module):
fall_back_to_pt_during_load = False
def __init__(
......@@ -1296,7 +1278,10 @@ class DeepseekV2Model(nn.Module):
positions, hidden_states, forward_batch, residual
)
if not forward_batch.forward_mode.is_idle():
hidden_states, _ = self.norm(hidden_states, residual)
if residual is None:
hidden_states = self.norm(hidden_states)
else:
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment