Unverified Commit 3c2274fb authored by Cheng Wan's avatar Cheng Wan Committed by GitHub
Browse files

Implement gather before attn (#6378)

parent d2679f51
......@@ -226,13 +226,13 @@ class LayerCommunicator:
@dataclass
class CommunicateContext:
process_group_sizes: Dict["ScatterMode", int]
process_group_sizes: Dict[ScatterMode, int]
attn_tp_rank: int
attn_tp_size: int
local_attn_dp_size: int
tp_size: int
def is_same_group_size(self, a: "ScatterMode", b: "ScatterMode"):
def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
return self.process_group_sizes[a] == self.process_group_sizes[b]
@classmethod
......@@ -244,6 +244,7 @@ class CommunicateContext:
process_group_sizes = {
ScatterMode.SCATTERED: 1,
ScatterMode.TP_ATTN_FULL: attn_tp_size,
# TODO: support --moe-dense-tp-size > 1
ScatterMode.FULL: tp_size,
}
return cls(
......@@ -323,11 +324,16 @@ class CommunicateWithAllReduceAndLayerNormFn:
if (
(hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
and (residual_input_mode == ScatterMode.TP_ATTN_FULL)
and (
residual_input_mode in [ScatterMode.SCATTERED, ScatterMode.TP_ATTN_FULL]
)
and (hidden_states_output_mode == ScatterMode.FULL)
and (residual_output_mode == ScatterMode.TP_ATTN_FULL)
):
return CommunicateWithAllReduceAndLayerNormFn._gather_hidden_states
return partial(
CommunicateWithAllReduceAndLayerNormFn._gather_hidden_states_and_residual,
residual_input_mode=residual_input_mode,
)
if (
(hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
......@@ -360,13 +366,25 @@ class CommunicateWithAllReduceAndLayerNormFn:
return hidden_states, residual
@staticmethod
def _gather_hidden_states(
def _gather_hidden_states_and_residual(
hidden_states: torch.Tensor,
residual: torch.Tensor,
forward_batch: ForwardBatch,
layernorm: torch.nn.Module,
context: CommunicateContext,
*,
residual_input_mode,
):
if residual_input_mode == ScatterMode.SCATTERED and context.attn_tp_size > 1:
residual, local_residual = (
forward_batch.gathered_buffer[
: forward_batch.input_ids.shape[0]
].clone(),
residual,
)
attn_tp_all_gather(
list(residual.tensor_split(context.attn_tp_size)), local_residual
)
if context.local_attn_dp_size != 1:
if context.attn_tp_rank == 0:
hidden_states += residual
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment