Unverified commit f4560373 authored by fzyzcjy, committed by GitHub

Utilize static dispatching for communicator (#6577)

parent b2388433
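At a high level, the commit replaces per-call dispatch with static dispatch: instead of rebuilding a `_Context` and re-checking the scatter modes inside every `_communicate_*` call, `LayerCommunicator.__init__` now resolves each communication handler once via `get_fn(...)` and stores the bound callable, so the forward path only invokes it. Below is a minimal, editorial sketch of that pattern (not part of the commit); the names `Mode`, `Context`, `CommunicateFn`, and `Layer` are made up, and plain lists stand in for tensors and collectives.

```python
# Editorial sketch (not part of the commit): the static-dispatch pattern the
# diff applies, with made-up names and plain lists standing in for tensors.
from dataclasses import dataclass
from enum import Enum, auto
from typing import Callable


class Mode(Enum):
    SCATTERED = auto()
    FULL = auto()


@dataclass
class Context:
    group_size: int


class CommunicateFn:
    """Resolve the handler once from (input_mode, output_mode); call it many times."""

    @staticmethod
    def get_fn(input_mode: Mode, output_mode: Mode, context: Context) -> Callable:
        # All mode comparisons happen here, at construction time.
        if input_mode == output_mode:
            return CommunicateFn._trivial
        if input_mode == Mode.SCATTERED and output_mode == Mode.FULL:
            return CommunicateFn._gather
        raise NotImplementedError(f"{input_mode=} {output_mode=}")

    @staticmethod
    def _trivial(x: list, context: Context) -> list:
        return x

    @staticmethod
    def _gather(x: list, context: Context) -> list:
        # Stand-in for an all-gather: replicate the local shard group_size times.
        return x * context.group_size


class Layer:
    def __init__(self, input_mode: Mode, output_mode: Mode):
        self.context = Context(group_size=4)
        # Dispatch once; the forward path below never re-checks the modes.
        self._communicate_fn = CommunicateFn.get_fn(input_mode, output_mode, self.context)

    def forward(self, x: list) -> list:
        return self._communicate_fn(x, context=self.context)


if __name__ == "__main__":
    layer = Layer(Mode.SCATTERED, Mode.FULL)
    print(layer.forward([1.0, 2.0]))  # [1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0]
```

Note the diff also drops the per-batch `num_tokens_of_mode` computation and the `check_shape` assertions that depended on it, since the resolved handlers no longer rebuild a `_Context` for every forward batch.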
@@ -14,7 +14,8 @@
 from dataclasses import dataclass
 from enum import Enum, auto
-from typing import Dict, Optional, Tuple
+from functools import partial
+from typing import Dict, Optional

 import torch.distributed
@@ -145,6 +146,36 @@ class LayerCommunicator:
             ScatterMode.FULL: self.tp_size,
         }

+        self._context = _Context(
+            process_group_sizes=self.process_group_sizes,
+            attn_tp_rank=self.attn_tp_rank,
+            attn_tp_size=self.attn_tp_size,
+            local_attn_dp_size=self.local_attn_dp_size,
+            tp_size=self.tp_size,
+        )
+        self._communicate_simple_fn = _CommunicateSimpleFn.get_fn(
+            input_mode=self.layer_scatter_modes.layer_input_mode,
+            output_mode=self.layer_scatter_modes.attn_mode,
+            context=self._context,
+        )
+        self._communicate_with_all_reduce_and_layer_norm_fn = (
+            _CommunicateWithAllReduceAndLayerNormFn.get_fn(
+                hidden_states_input_mode=self.layer_scatter_modes.attn_mode,
+                residual_input_mode=self.layer_scatter_modes.layer_input_mode,
+                hidden_states_output_mode=self.layer_scatter_modes.mlp_mode,
+                residual_output_mode=self.layer_scatter_modes.middle_residual_mode,
+                context=self._context,
+            )
+        )
+        self._communicate_summable_tensor_pair_fn = (
+            _CommunicateSummableTensorPairFn.get_fn(
+                hidden_states_input_mode=self.layer_scatter_modes.mlp_mode,
+                residual_input_mode=self.layer_scatter_modes.middle_residual_mode,
+                output_mode=self.layer_scatter_modes.layer_output_mode,
+                context=self._context,
+            )
+        )
+
     def prepare_attn(
         self,
         hidden_states: torch.Tensor,
@@ -160,12 +191,10 @@ class LayerCommunicator:
         else:
             hidden_states, residual = self.input_layernorm(hidden_states, residual)

-        hidden_states = _communicate_simple(
+        hidden_states = self._communicate_simple_fn(
             hidden_states=hidden_states,
             forward_batch=forward_batch,
-            input_mode=self.layer_scatter_modes.layer_input_mode,
-            output_mode=self.layer_scatter_modes.attn_mode,
-            context=self._compute_context(forward_batch),
+            context=self._context,
         )

         return hidden_states, residual
@@ -176,16 +205,12 @@ class LayerCommunicator:
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
     ):
-        return _communicate_with_all_reduce_and_layer_norm(
+        return self._communicate_with_all_reduce_and_layer_norm_fn(
             hidden_states=hidden_states,
             residual=residual,
             forward_batch=forward_batch,
-            hidden_states_input_mode=self.layer_scatter_modes.attn_mode,
-            residual_input_mode=self.layer_scatter_modes.layer_input_mode,
-            hidden_states_output_mode=self.layer_scatter_modes.mlp_mode,
-            residual_output_mode=self.layer_scatter_modes.middle_residual_mode,
             layernorm=self.post_attention_layernorm,
-            context=self._compute_context(forward_batch),
+            context=self._context,
         )

     def postprocess_layer(
@@ -194,58 +219,16 @@ class LayerCommunicator:
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
     ):
-        return _communicate_summable_tensor_pair(
+        return self._communicate_summable_tensor_pair_fn(
             hidden_states=hidden_states,
             residual=residual,
             forward_batch=forward_batch,
-            hidden_states_input_mode=self.layer_scatter_modes.mlp_mode,
-            residual_input_mode=self.layer_scatter_modes.middle_residual_mode,
-            output_mode=self.layer_scatter_modes.layer_output_mode,
-            context=self._compute_context(forward_batch),
+            context=self._context,
         )
-
-    def _compute_context(self, forward_batch: ForwardBatch):
-        return _Context(
-            num_tokens_of_mode=_compute_num_tokens_of_mode(
-                forward_batch,
-                attn_tp_rank=self.attn_tp_rank,
-                attn_tp_size=self.attn_tp_size,
-            ),
-            process_group_sizes=self.process_group_sizes,
-            attn_tp_rank=self.attn_tp_rank,
-            attn_tp_size=self.attn_tp_size,
-            local_attn_dp_size=self.local_attn_dp_size,
-            tp_size=self.tp_size,
-        )
-
-
-def _compute_num_tokens_of_mode(
-    forward_batch: ForwardBatch, attn_tp_rank: int, attn_tp_size: int
-):
-    tp_attn_full_num_tokens = forward_batch.input_ids.shape[0]
-    return {
-        ScatterMode.SCATTERED: _torch_tensor_split_len(
-            tp_attn_full_num_tokens, attn_tp_size, attn_tp_rank
-        ),
-        ScatterMode.TP_ATTN_FULL: tp_attn_full_num_tokens,
-        ScatterMode.FULL: (
-            forward_batch.gathered_buffer.shape[0]
-            if global_server_args_dict["enable_dp_attention"]
-            else forward_batch.input_ids.shape[0]
-        ),
-    }
-
-
-def _torch_tensor_split_len(tensor_len: int, n: int, output_index: int):
-    if output_index < int(tensor_len % n):
-        return int(tensor_len / n) + 1
-    else:
-        return int(tensor_len / n)


 @dataclass
 class _Context:
-    num_tokens_of_mode: Dict["ScatterMode", int]
     process_group_sizes: Dict["ScatterMode", int]
     attn_tp_rank: int
     attn_tp_size: int
@@ -255,41 +238,38 @@ class _Context:
     def is_same_group_size(self, a: "ScatterMode", b: "ScatterMode"):
         return self.process_group_sizes[a] == self.process_group_sizes[b]

-    def check_shape(self, x: torch.Tensor, mode: ScatterMode):
-        if x is None:
-            return
-
-        actual_num_tokens = x.shape[0]
-        expect_num_tokens = self.num_tokens_of_mode[mode]
-        assert (
-            actual_num_tokens == expect_num_tokens
-        ), f"{actual_num_tokens=} {expect_num_tokens=} {mode=} {x.shape=} {self.num_tokens_of_mode=} {self.process_group_sizes=}"
-        return x
-
-    def check_shapes(
-        self, xs: Tuple[torch.Tensor, ...], modes: Tuple[ScatterMode, ...]
-    ) -> Tuple[torch.Tensor, ...]:
-        return tuple(
-            [self.check_shape(x, mode) for x, mode in zip(xs, modes, strict=True)]
-        )

-
-def _communicate_simple(
-    hidden_states: torch.Tensor,
-    forward_batch: ForwardBatch,
+class _CommunicateSimpleFn:
+    @staticmethod
+    def get_fn(
         input_mode: ScatterMode,
         output_mode: ScatterMode,
         context: _Context,
-) -> torch.Tensor:
-    def _inner():
-        nonlocal hidden_states
-
+    ):
         if context.is_same_group_size(input_mode, output_mode):
-            return hidden_states
+            return _CommunicateSimpleFn._trivial

         if (input_mode == ScatterMode.SCATTERED) and (
             output_mode == ScatterMode.TP_ATTN_FULL
         ):
+            return _CommunicateSimpleFn._scattered_to_tp_attn_full
+
+        raise NotImplementedError(f"{input_mode=} {output_mode=}")
+
+    @staticmethod
+    def _trivial(
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        context: _Context,
+    ) -> torch.Tensor:
+        return hidden_states
+
+    @staticmethod
+    def _scattered_to_tp_attn_full(
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        context: _Context,
+    ) -> torch.Tensor:
         hidden_states, local_hidden_states = (
             forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
             hidden_states,
@@ -300,30 +280,21 @@ def _communicate_simple(
         )
         return hidden_states

-        raise NotImplementedError(f"{input_mode=} {output_mode=}")
-
-    context.check_shape(hidden_states, input_mode)
-    return context.check_shape(_inner(), output_mode)
-

-def _communicate_with_all_reduce_and_layer_norm(
-    hidden_states: torch.Tensor,
-    residual: torch.Tensor,
+class _CommunicateWithAllReduceAndLayerNormFn:
+    """Besides communication, needs to
+    1. All reduce in tp_attn_group on hidden_states
+    2. Apply layer norm
+    """
+
+    @staticmethod
+    def get_fn(
         hidden_states_input_mode: ScatterMode,
         residual_input_mode: ScatterMode,
         hidden_states_output_mode: ScatterMode,
         residual_output_mode: ScatterMode,
-    forward_batch: ForwardBatch,
-    layernorm: torch.nn.Module,
         context: _Context,
     ):
-    """Besides communication, needs to
-    1. All reduce in tp_attn_group on hidden_states
-    2. Apply layer norm
-    """
-
-    def _inner():
-        nonlocal hidden_states, residual
-
         if (
             context.is_same_group_size(
@@ -332,16 +303,53 @@ def _communicate_with_all_reduce_and_layer_norm(
             and context.is_same_group_size(residual_input_mode, residual_output_mode)
             and context.attn_tp_size == 1
         ):
-            # TODO move these `if shape != 0` into LayerNorm itself
-            if hidden_states.shape[0] != 0:
-                hidden_states, residual = layernorm(hidden_states, residual)
-            return hidden_states, residual
+            return _CommunicateWithAllReduceAndLayerNormFn._simple

         if (
             (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
             and (residual_input_mode == ScatterMode.TP_ATTN_FULL)
             and (hidden_states_output_mode == ScatterMode.FULL)
             and (residual_output_mode == ScatterMode.TP_ATTN_FULL)
         ):
+            return _CommunicateWithAllReduceAndLayerNormFn._gather_hidden_states
+
+        if (
+            (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
+            and (
+                residual_input_mode in [ScatterMode.SCATTERED, ScatterMode.TP_ATTN_FULL]
+            )
+            and (hidden_states_output_mode == ScatterMode.SCATTERED)
+            and (residual_output_mode == ScatterMode.SCATTERED)
+        ):
+            return partial(
+                _CommunicateWithAllReduceAndLayerNormFn._scatter_hidden_states_and_residual,
+                residual_input_mode=residual_input_mode,
+            )
+
+        raise NotImplementedError(
+            f"{hidden_states_input_mode=} {residual_input_mode=} {residual_output_mode=} {residual_output_mode=}"
+        )
+
+    @staticmethod
+    def _simple(
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
+        forward_batch: ForwardBatch,
+        layernorm: torch.nn.Module,
+        context: _Context,
+    ):
+        # TODO move these `if shape != 0` into LayerNorm itself
+        if hidden_states.shape[0] != 0:
+            hidden_states, residual = layernorm(hidden_states, residual)
+        return hidden_states, residual
+
+    @staticmethod
+    def _gather_hidden_states(
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
+        forward_batch: ForwardBatch,
+        layernorm: torch.nn.Module,
+        context: _Context,
+    ):
         if context.local_attn_dp_size != 1:
             if context.attn_tp_rank == 0:
@@ -359,60 +367,75 @@ def _communicate_with_all_reduce_and_layer_norm(
         hidden_states, residual = layernorm(hidden_states, residual)
         return hidden_states, residual

-        if (
-            (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
-            and (
-                residual_input_mode in [ScatterMode.SCATTERED, ScatterMode.TP_ATTN_FULL]
-            )
-            and (hidden_states_output_mode == ScatterMode.SCATTERED)
-            and (residual_output_mode == ScatterMode.SCATTERED)
-        ):
+    @staticmethod
+    def _scatter_hidden_states_and_residual(
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
+        forward_batch: ForwardBatch,
+        layernorm: torch.nn.Module,
+        context: _Context,
+        *,
+        residual_input_mode,
+    ):
         tensor_list = list(hidden_states.tensor_split(context.attn_tp_size))
         hidden_states = tensor_list[context.attn_tp_rank]
         attn_tp_reduce_scatter(hidden_states, tensor_list)
         if residual_input_mode == ScatterMode.TP_ATTN_FULL:
-                residual = residual.tensor_split(context.attn_tp_size)[
-                    context.attn_tp_rank
-                ]
+            residual = residual.tensor_split(context.attn_tp_size)[context.attn_tp_rank]
         if hidden_states.shape[0] != 0:
             hidden_states, residual = layernorm(hidden_states, residual)
         return hidden_states, residual

-        raise NotImplementedError(
-            f"{hidden_states_input_mode=} {residual_input_mode=} {residual_output_mode=} {residual_output_mode=}"
-        )
-
-    context.check_shapes(
-        (hidden_states, residual), (hidden_states_input_mode, residual_input_mode)
-    )
-    return context.check_shapes(
-        _inner(), (hidden_states_output_mode, residual_output_mode)
-    )

-
-def _communicate_summable_tensor_pair(
-    hidden_states: torch.Tensor,
-    residual: torch.Tensor,
-    forward_batch: ForwardBatch,
+class _CommunicateSummableTensorPairFn:
+    @staticmethod
+    def get_fn(
         hidden_states_input_mode: ScatterMode,
         residual_input_mode: ScatterMode,
         output_mode: ScatterMode,
         context: _Context,
     ):
         """It is allowed to make (hidden_states, residual) := (hidden_states + residual, None) if needed."""

-    def _inner():
-        nonlocal hidden_states, residual
-
         if context.is_same_group_size(
             hidden_states_input_mode, output_mode
         ) and context.is_same_group_size(residual_input_mode, output_mode):
-            return hidden_states, residual
+            return _CommunicateSummableTensorPairFn._trivial

         if (
             (hidden_states_input_mode == ScatterMode.FULL)
             and (residual_input_mode == ScatterMode.TP_ATTN_FULL)
             and (output_mode == ScatterMode.TP_ATTN_FULL)
         ):
+            return _CommunicateSummableTensorPairFn._scatter_hidden_states
+
+        if (
+            (hidden_states_input_mode == ScatterMode.SCATTERED)
+            and (residual_input_mode == ScatterMode.SCATTERED)
+            and (output_mode == ScatterMode.TP_ATTN_FULL)
+        ):
+            return _CommunicateSummableTensorPairFn._gather
+
+        raise NotImplementedError(
+            f"{hidden_states_input_mode=} {residual_input_mode=} {output_mode=}"
+        )
+
+    @staticmethod
+    def _trivial(
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
+        forward_batch: ForwardBatch,
+        context: _Context,
+    ):
+        return hidden_states, residual
+
+    @staticmethod
+    def _scatter_hidden_states(
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
+        forward_batch: ForwardBatch,
+        context: _Context,
+    ):
         # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
         # important: forward batch.gathered_buffer is used both after scatter and after gather.
@@ -424,10 +447,12 @@ def _communicate_summable_tensor_pair(
         dp_scatter(hidden_states, global_hidden_states, forward_batch)
         return hidden_states, residual

-        if (
-            (hidden_states_input_mode == ScatterMode.SCATTERED)
-            and (residual_input_mode == ScatterMode.SCATTERED)
-            and (output_mode == ScatterMode.TP_ATTN_FULL)
+    @staticmethod
+    def _gather(
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
+        forward_batch: ForwardBatch,
+        context: _Context,
     ):
         hidden_states += residual
         residual = None
@@ -440,12 +465,3 @@ def _communicate_summable_tensor_pair(
             local_hidden_states,
         )
         return hidden_states, residual
-
-        raise NotImplementedError(
-            f"{hidden_states_input_mode=} {residual_input_mode=} {output_mode=}"
-        )
-
-    context.check_shapes(
-        (hidden_states, residual), (hidden_states_input_mode, residual_input_mode)
-    )
-    return context.check_shapes(_inner(), (output_mode, output_mode))
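One detail from the `_CommunicateWithAllReduceAndLayerNormFn.get_fn` hunk above: the scatter branch returns `functools.partial(..., residual_input_mode=residual_input_mode)`, freezing the one branch-specific argument at dispatch time so every handler keeps the same call-time signature. A tiny standalone sketch of that idiom (the names here are illustrative, not from the codebase):

```python
# Editorial sketch: freeze a branch-specific keyword with functools.partial so
# all dispatched handlers share one call-time signature (illustrative names).
from functools import partial


def _scatter(hidden_states, residual, *, residual_input_mode):
    # residual_input_mode was fixed when the handler was chosen, not per call.
    return f"scatter(residual_input_mode={residual_input_mode})"


# Dispatch time: bind the branch-specific argument once.
handler = partial(_scatter, residual_input_mode="TP_ATTN_FULL")

# Call time: only the per-batch values are passed, same as for the other handlers.
print(handler(hidden_states=None, residual=None))
# -> scatter(residual_input_mode=TP_ATTN_FULL)
```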