Unverified Commit 3c2274fb authored by Cheng Wan, committed by GitHub

Implement gather before attn (#6378)

parent d2679f51
@@ -226,13 +226,13 @@ class LayerCommunicator:
 @dataclass
 class CommunicateContext:
-    process_group_sizes: Dict["ScatterMode", int]
+    process_group_sizes: Dict[ScatterMode, int]
     attn_tp_rank: int
     attn_tp_size: int
     local_attn_dp_size: int
     tp_size: int

-    def is_same_group_size(self, a: "ScatterMode", b: "ScatterMode"):
+    def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
         return self.process_group_sizes[a] == self.process_group_sizes[b]

     @classmethod
@@ -244,6 +244,7 @@ class CommunicateContext:
         process_group_sizes = {
             ScatterMode.SCATTERED: 1,
             ScatterMode.TP_ATTN_FULL: attn_tp_size,
+            # TODO: support --moe-dense-tp-size > 1
             ScatterMode.FULL: tp_size,
         }
         return cls(
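The hunk above hard-codes the three layouts a tensor can be in and how many ranks share each one; `is_same_group_size` is what lets later dispatch code treat two modes as interchangeable when their group sizes match. A minimal, self-contained sketch of that relationship (the enum values and sizes here are hypothetical stand-ins, not SGLang's real initialization path):

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import Dict


class ScatterMode(Enum):    # stand-in for sglang's ScatterMode
    SCATTERED = auto()      # each rank holds 1/attn_tp_size of the tokens
    TP_ATTN_FULL = auto()   # every rank in the attention TP group holds all tokens
    FULL = auto()           # gathered across the whole TP group


@dataclass
class CommunicateContext:
    process_group_sizes: Dict[ScatterMode, int]

    def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
        return self.process_group_sizes[a] == self.process_group_sizes[b]


# Hypothetical example: attn_tp_size == tp_size == 8 (pure TP, no attention
# DP), so TP_ATTN_FULL and FULL describe the same layout and no extra gather
# is needed when converting between them.
ctx = CommunicateContext(
    process_group_sizes={
        ScatterMode.SCATTERED: 1,
        ScatterMode.TP_ATTN_FULL: 8,
        ScatterMode.FULL: 8,
    }
)
assert ctx.is_same_group_size(ScatterMode.TP_ATTN_FULL, ScatterMode.FULL)
```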
@@ -323,11 +324,16 @@ class CommunicateWithAllReduceAndLayerNormFn:
         if (
             (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
-            and (residual_input_mode == ScatterMode.TP_ATTN_FULL)
+            and (
+                residual_input_mode in [ScatterMode.SCATTERED, ScatterMode.TP_ATTN_FULL]
+            )
             and (hidden_states_output_mode == ScatterMode.FULL)
             and (residual_output_mode == ScatterMode.TP_ATTN_FULL)
         ):
-            return CommunicateWithAllReduceAndLayerNormFn._gather_hidden_states
+            return partial(
+                CommunicateWithAllReduceAndLayerNormFn._gather_hidden_states_and_residual,
+                residual_input_mode=residual_input_mode,
+            )

         if (
             (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
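This dispatch now matches two possible residual input modes instead of one, so it can no longer return a bare staticmethod; it returns a `functools.partial` with `residual_input_mode` pre-bound, keeping the call sites unchanged. A minimal sketch of the pattern (the handler body and mode values are hypothetical):

```python
from functools import partial


def _gather(hidden_states, residual, *, residual_input_mode):
    # A real handler would branch on residual_input_mode here.
    return hidden_states, residual


def pick_fn(residual_input_mode):
    # Bind the mode at selection time; callers keep the old two-argument call.
    if residual_input_mode in ("SCATTERED", "TP_ATTN_FULL"):
        return partial(_gather, residual_input_mode=residual_input_mode)
    raise NotImplementedError


fn = pick_fn("SCATTERED")
print(fn("h", "r"))  # caller passes only the tensors, as before
```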
@@ -360,13 +366,25 @@ class CommunicateWithAllReduceAndLayerNormFn:
         return hidden_states, residual

     @staticmethod
-    def _gather_hidden_states(
+    def _gather_hidden_states_and_residual(
         hidden_states: torch.Tensor,
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
         layernorm: torch.nn.Module,
         context: CommunicateContext,
+        *,
+        residual_input_mode,
     ):
+        if residual_input_mode == ScatterMode.SCATTERED and context.attn_tp_size > 1:
+            residual, local_residual = (
+                forward_batch.gathered_buffer[
+                    : forward_batch.input_ids.shape[0]
+                ].clone(),
+                residual,
+            )
+            attn_tp_all_gather(
+                list(residual.tensor_split(context.attn_tp_size)), local_residual
+            )
         if context.local_attn_dp_size != 1:
             if context.attn_tp_rank == 0:
                 hidden_states += residual
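The new branch is the "gather before attn" step: when the residual arrives SCATTERED, each rank holds only 1/attn_tp_size of the tokens, so the code carves a full-size destination out of `forward_batch.gathered_buffer`, splits it into per-rank views with `tensor_split`, and lets `attn_tp_all_gather` fill them. A single-process sketch that emulates the collective with plain copies (the shapes and the serial loop are assumptions for illustration; `attn_tp_all_gather` is SGLang's wrapper around the real collective):

```python
import torch

attn_tp_size = 4
num_tokens, hidden = 8, 16                         # assumed batch geometry
gathered_buffer = torch.empty(num_tokens, hidden)  # preallocated, like forward_batch.gathered_buffer

# What each rank holds before the gather: its own 1/attn_tp_size token shard.
local_shards = [
    torch.randn(num_tokens // attn_tp_size, hidden) for _ in range(attn_tp_size)
]

# Mirror the diff: clone a full-size destination out of the shared buffer ...
residual = gathered_buffer[:num_tokens].clone()
# ... and split it into one view per rank; tensor_split returns views that
# alias `residual`, so writing into them fills the full tensor in place.
views = list(residual.tensor_split(attn_tp_size))
for view, shard in zip(views, local_shards):       # the all-gather, emulated serially
    view.copy_(shard)

# Every rank would now hold the complete residual.
assert torch.equal(residual, torch.cat(local_shards))
```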