Unverified Commit 32f28154 authored by Trevor Morris's avatar Trevor Morris Committed by GitHub
Browse files

Do layernorm before allgather for DP attention (#8631)

parent f7b2853f
...@@ -404,11 +404,21 @@ class CommunicateWithAllReduceAndLayerNormFn: ...@@ -404,11 +404,21 @@ class CommunicateWithAllReduceAndLayerNormFn:
if context.attn_dp_size != 1: if context.attn_dp_size != 1:
if context.attn_tp_rank == 0: if context.attn_tp_rank == 0:
hidden_states += residual hidden_states += residual
# Perform layernorm on smaller data before comm. Only valid when attn_tp_size is 1 (tp_size == dp_size)
use_layer_norm_before_gather = context.attn_tp_size == 1
if use_layer_norm_before_gather:
residual.copy_(hidden_states)
if hidden_states.shape[0] != 0:
hidden_states = layernorm(hidden_states)
hidden_states, local_hidden_states = ( hidden_states, local_hidden_states = (
forward_batch.gathered_buffer, forward_batch.gathered_buffer,
hidden_states, hidden_states,
) )
dp_gather_partial(hidden_states, local_hidden_states, forward_batch) dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
if not use_layer_norm_before_gather:
dp_scatter(residual, hidden_states, forward_batch) dp_scatter(residual, hidden_states, forward_batch)
if hidden_states.shape[0] != 0: if hidden_states.shape[0] != 0:
hidden_states = layernorm(hidden_states) hidden_states = layernorm(hidden_states)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment