Unverified Commit f4560373 authored by fzyzcjy, committed by GitHub

Utilize static dispatching for communicator (#6577)

parent b2388433
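
This commit replaces dynamic dispatch in the communicator's hot path with static dispatch: instead of rebuilding a _Context and re-branching on the scatter modes at every forward call, each LayerCommunicator now resolves the concrete communication function once in __init__ and simply invokes it afterwards. A minimal sketch of the pattern (illustrative names only, not the actual SGLang code):

# Hypothetical, self-contained illustration of static dispatching.
from enum import Enum, auto


class Mode(Enum):
    SCATTERED = auto()
    FULL = auto()


class Communicator:
    def __init__(self, input_mode: Mode, output_mode: Mode):
        # Resolve the branch once, at construction time.
        if input_mode == output_mode:
            self._fn = Communicator._trivial
        elif (input_mode, output_mode) == (Mode.SCATTERED, Mode.FULL):
            self._fn = Communicator._gather
        else:
            raise NotImplementedError(f"{input_mode=} {output_mode=}")

    def __call__(self, hidden_states):
        # Hot path: a direct call, no per-step mode checks.
        return self._fn(hidden_states)

    @staticmethod
    def _trivial(hidden_states):
        return hidden_states

    @staticmethod
    def _gather(hidden_states):
        # Stand-in for a collective (e.g. all-gather over the attention TP group).
        return hidden_states


comm = Communicator(Mode.SCATTERED, Mode.FULL)  # selection happens here
out = comm([1.0, 2.0])                          # every later call is direct

The one-time selection is safe because the dispatch condition depends only on values fixed at construction time (scatter modes and process-group sizes), never on per-batch state.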
@@ -14,7 +14,8 @@
 from dataclasses import dataclass
 from enum import Enum, auto
-from typing import Dict, Optional, Tuple
+from functools import partial
+from typing import Dict, Optional
 
 import torch.distributed
@@ -145,6 +146,36 @@ class LayerCommunicator:
             ScatterMode.FULL: self.tp_size,
         }
 
+        self._context = _Context(
+            process_group_sizes=self.process_group_sizes,
+            attn_tp_rank=self.attn_tp_rank,
+            attn_tp_size=self.attn_tp_size,
+            local_attn_dp_size=self.local_attn_dp_size,
+            tp_size=self.tp_size,
+        )
+
+        self._communicate_simple_fn = _CommunicateSimpleFn.get_fn(
+            input_mode=self.layer_scatter_modes.layer_input_mode,
+            output_mode=self.layer_scatter_modes.attn_mode,
+            context=self._context,
+        )
+        self._communicate_with_all_reduce_and_layer_norm_fn = (
+            _CommunicateWithAllReduceAndLayerNormFn.get_fn(
+                hidden_states_input_mode=self.layer_scatter_modes.attn_mode,
+                residual_input_mode=self.layer_scatter_modes.layer_input_mode,
+                hidden_states_output_mode=self.layer_scatter_modes.mlp_mode,
+                residual_output_mode=self.layer_scatter_modes.middle_residual_mode,
+                context=self._context,
+            )
+        )
+        self._communicate_summable_tensor_pair_fn = (
+            _CommunicateSummableTensorPairFn.get_fn(
+                hidden_states_input_mode=self.layer_scatter_modes.mlp_mode,
+                residual_input_mode=self.layer_scatter_modes.middle_residual_mode,
+                output_mode=self.layer_scatter_modes.layer_output_mode,
+                context=self._context,
+            )
+        )
+
     def prepare_attn(
         self,
         hidden_states: torch.Tensor,
@@ -160,12 +191,10 @@ class LayerCommunicator:
         else:
             hidden_states, residual = self.input_layernorm(hidden_states, residual)
 
-        hidden_states = _communicate_simple(
+        hidden_states = self._communicate_simple_fn(
             hidden_states=hidden_states,
             forward_batch=forward_batch,
-            input_mode=self.layer_scatter_modes.layer_input_mode,
-            output_mode=self.layer_scatter_modes.attn_mode,
-            context=self._compute_context(forward_batch),
+            context=self._context,
         )
 
         return hidden_states, residual
@@ -176,16 +205,12 @@ class LayerCommunicator:
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
     ):
-        return _communicate_with_all_reduce_and_layer_norm(
+        return self._communicate_with_all_reduce_and_layer_norm_fn(
            hidden_states=hidden_states,
            residual=residual,
            forward_batch=forward_batch,
-            hidden_states_input_mode=self.layer_scatter_modes.attn_mode,
-            residual_input_mode=self.layer_scatter_modes.layer_input_mode,
-            hidden_states_output_mode=self.layer_scatter_modes.mlp_mode,
-            residual_output_mode=self.layer_scatter_modes.middle_residual_mode,
            layernorm=self.post_attention_layernorm,
-            context=self._compute_context(forward_batch),
+            context=self._context,
        )
 
     def postprocess_layer(
@@ -194,58 +219,16 @@ class LayerCommunicator:
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
     ):
-        return _communicate_summable_tensor_pair(
+        return self._communicate_summable_tensor_pair_fn(
            hidden_states=hidden_states,
            residual=residual,
            forward_batch=forward_batch,
-            hidden_states_input_mode=self.layer_scatter_modes.mlp_mode,
-            residual_input_mode=self.layer_scatter_modes.middle_residual_mode,
-            output_mode=self.layer_scatter_modes.layer_output_mode,
-            context=self._compute_context(forward_batch),
+            context=self._context,
        )
 
-    def _compute_context(self, forward_batch: ForwardBatch):
-        return _Context(
-            num_tokens_of_mode=_compute_num_tokens_of_mode(
-                forward_batch,
-                attn_tp_rank=self.attn_tp_rank,
-                attn_tp_size=self.attn_tp_size,
-            ),
-            process_group_sizes=self.process_group_sizes,
-            attn_tp_rank=self.attn_tp_rank,
-            attn_tp_size=self.attn_tp_size,
-            local_attn_dp_size=self.local_attn_dp_size,
-            tp_size=self.tp_size,
-        )
-
-
-def _compute_num_tokens_of_mode(
-    forward_batch: ForwardBatch, attn_tp_rank: int, attn_tp_size: int
-):
-    tp_attn_full_num_tokens = forward_batch.input_ids.shape[0]
-    return {
-        ScatterMode.SCATTERED: _torch_tensor_split_len(
-            tp_attn_full_num_tokens, attn_tp_size, attn_tp_rank
-        ),
-        ScatterMode.TP_ATTN_FULL: tp_attn_full_num_tokens,
-        ScatterMode.FULL: (
-            forward_batch.gathered_buffer.shape[0]
-            if global_server_args_dict["enable_dp_attention"]
-            else forward_batch.input_ids.shape[0]
-        ),
-    }
-
-
-def _torch_tensor_split_len(tensor_len: int, n: int, output_index: int):
-    if output_index < int(tensor_len % n):
-        return int(tensor_len / n) + 1
-    else:
-        return int(tensor_len / n)
-
 
 @dataclass
 class _Context:
-    num_tokens_of_mode: Dict["ScatterMode", int]
     process_group_sizes: Dict["ScatterMode", int]
     attn_tp_rank: int
     attn_tp_size: int
@@ -255,75 +238,63 @@ class _Context:
     def is_same_group_size(self, a: "ScatterMode", b: "ScatterMode"):
         return self.process_group_sizes[a] == self.process_group_sizes[b]
 
-    def check_shape(self, x: torch.Tensor, mode: ScatterMode):
-        if x is None:
-            return
-        actual_num_tokens = x.shape[0]
-        expect_num_tokens = self.num_tokens_of_mode[mode]
-        assert (
-            actual_num_tokens == expect_num_tokens
-        ), f"{actual_num_tokens=} {expect_num_tokens=} {mode=} {x.shape=} {self.num_tokens_of_mode=} {self.process_group_sizes=}"
-        return x
-
-    def check_shapes(
-        self, xs: Tuple[torch.Tensor, ...], modes: Tuple[ScatterMode, ...]
-    ) -> Tuple[torch.Tensor, ...]:
-        return tuple(
-            [self.check_shape(x, mode) for x, mode in zip(xs, modes, strict=True)]
-        )
-
-
-def _communicate_simple(
-    hidden_states: torch.Tensor,
-    forward_batch: ForwardBatch,
-    input_mode: ScatterMode,
-    output_mode: ScatterMode,
-    context: _Context,
-) -> torch.Tensor:
-    def _inner():
-        nonlocal hidden_states
-
+
+class _CommunicateSimpleFn:
+    @staticmethod
+    def get_fn(
+        input_mode: ScatterMode,
+        output_mode: ScatterMode,
+        context: _Context,
+    ):
         if context.is_same_group_size(input_mode, output_mode):
-            return hidden_states
+            return _CommunicateSimpleFn._trivial
 
         if (input_mode == ScatterMode.SCATTERED) and (
             output_mode == ScatterMode.TP_ATTN_FULL
         ):
-            hidden_states, local_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            attn_tp_all_gather(
-                list(hidden_states.tensor_split(context.attn_tp_size)),
-                local_hidden_states,
-            )
-            return hidden_states
+            return _CommunicateSimpleFn._scattered_to_tp_attn_full
 
         raise NotImplementedError(f"{input_mode=} {output_mode=}")
 
-    context.check_shape(hidden_states, input_mode)
-    return context.check_shape(_inner(), output_mode)
-
-
-def _communicate_with_all_reduce_and_layer_norm(
-    hidden_states: torch.Tensor,
-    residual: torch.Tensor,
-    hidden_states_input_mode: ScatterMode,
-    residual_input_mode: ScatterMode,
-    hidden_states_output_mode: ScatterMode,
-    residual_output_mode: ScatterMode,
-    forward_batch: ForwardBatch,
-    layernorm: torch.nn.Module,
-    context: _Context,
-):
+    @staticmethod
+    def _trivial(
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        context: _Context,
+    ) -> torch.Tensor:
+        return hidden_states
+
+    @staticmethod
+    def _scattered_to_tp_attn_full(
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        context: _Context,
+    ) -> torch.Tensor:
+        hidden_states, local_hidden_states = (
+            forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+            hidden_states,
+        )
+        attn_tp_all_gather(
+            list(hidden_states.tensor_split(context.attn_tp_size)),
+            local_hidden_states,
+        )
+        return hidden_states
+
+
+class _CommunicateWithAllReduceAndLayerNormFn:
     """Besides communication, needs to
     1. All reduce in tp_attn_group on hidden_states
     2. Apply layer norm
     """
 
-    def _inner():
-        nonlocal hidden_states, residual
-
+    @staticmethod
+    def get_fn(
+        hidden_states_input_mode: ScatterMode,
+        residual_input_mode: ScatterMode,
+        hidden_states_output_mode: ScatterMode,
+        residual_output_mode: ScatterMode,
+        context: _Context,
+    ):
         if (
             context.is_same_group_size(
@@ -332,10 +303,7 @@ def _communicate_with_all_reduce_and_layer_norm(
             and context.is_same_group_size(residual_input_mode, residual_output_mode)
             and context.attn_tp_size == 1
         ):
-            # TODO move these `if shape != 0` into LayerNorm itself
-            if hidden_states.shape[0] != 0:
-                hidden_states, residual = layernorm(hidden_states, residual)
-            return hidden_states, residual
+            return _CommunicateWithAllReduceAndLayerNormFn._simple
 
         if (
             (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
@@ -343,21 +311,7 @@ def _communicate_with_all_reduce_and_layer_norm(
             and (hidden_states_output_mode == ScatterMode.FULL)
             and (residual_output_mode == ScatterMode.TP_ATTN_FULL)
         ):
-            if context.local_attn_dp_size != 1:
-                if context.attn_tp_rank == 0:
-                    hidden_states += residual
-                hidden_states, local_hidden_states = (
-                    forward_batch.gathered_buffer,
-                    hidden_states,
-                )
-                dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
-                dp_scatter(residual, hidden_states, forward_batch)
-                if hidden_states.shape[0] != 0:
-                    hidden_states = layernorm(hidden_states)
-            else:
-                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
-                hidden_states, residual = layernorm(hidden_states, residual)
-            return hidden_states, residual
+            return _CommunicateWithAllReduceAndLayerNormFn._gather_hidden_states
 
         if (
             (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
@@ -367,85 +321,147 @@ def _communicate_with_all_reduce_and_layer_norm(
             and (hidden_states_output_mode == ScatterMode.SCATTERED)
             and (residual_output_mode == ScatterMode.SCATTERED)
         ):
-            tensor_list = list(hidden_states.tensor_split(context.attn_tp_size))
-            hidden_states = tensor_list[context.attn_tp_rank]
-            attn_tp_reduce_scatter(hidden_states, tensor_list)
-            if residual_input_mode == ScatterMode.TP_ATTN_FULL:
-                residual = residual.tensor_split(context.attn_tp_size)[
-                    context.attn_tp_rank
-                ]
-            if hidden_states.shape[0] != 0:
-                hidden_states, residual = layernorm(hidden_states, residual)
-            return hidden_states, residual
+            return partial(
+                _CommunicateWithAllReduceAndLayerNormFn._scatter_hidden_states_and_residual,
+                residual_input_mode=residual_input_mode,
+            )
 
         raise NotImplementedError(
             f"{hidden_states_input_mode=} {residual_input_mode=} {hidden_states_output_mode=} {residual_output_mode=}"
         )
 
-    context.check_shapes(
-        (hidden_states, residual), (hidden_states_input_mode, residual_input_mode)
-    )
-    return context.check_shapes(
-        _inner(), (hidden_states_output_mode, residual_output_mode)
-    )
-
-
-def _communicate_summable_tensor_pair(
-    hidden_states: torch.Tensor,
-    residual: torch.Tensor,
-    forward_batch: ForwardBatch,
-    hidden_states_input_mode: ScatterMode,
-    residual_input_mode: ScatterMode,
-    output_mode: ScatterMode,
-    context: _Context,
-):
-    """It is allowed to make (hidden_states, residual) := (hidden_states + residual, None) if needed."""
-
-    def _inner():
-        nonlocal hidden_states, residual
-
+    @staticmethod
+    def _simple(
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
+        forward_batch: ForwardBatch,
+        layernorm: torch.nn.Module,
+        context: _Context,
+    ):
+        # TODO move these `if shape != 0` into LayerNorm itself
+        if hidden_states.shape[0] != 0:
+            hidden_states, residual = layernorm(hidden_states, residual)
+        return hidden_states, residual
+
+    @staticmethod
+    def _gather_hidden_states(
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
+        forward_batch: ForwardBatch,
+        layernorm: torch.nn.Module,
+        context: _Context,
+    ):
+        if context.local_attn_dp_size != 1:
+            if context.attn_tp_rank == 0:
+                hidden_states += residual
+            hidden_states, local_hidden_states = (
+                forward_batch.gathered_buffer,
+                hidden_states,
+            )
+            dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
+            dp_scatter(residual, hidden_states, forward_batch)
+            if hidden_states.shape[0] != 0:
+                hidden_states = layernorm(hidden_states)
+        else:
+            hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+            hidden_states, residual = layernorm(hidden_states, residual)
+        return hidden_states, residual
+
+    @staticmethod
+    def _scatter_hidden_states_and_residual(
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
+        forward_batch: ForwardBatch,
+        layernorm: torch.nn.Module,
+        context: _Context,
+        *,
+        residual_input_mode,
+    ):
+        tensor_list = list(hidden_states.tensor_split(context.attn_tp_size))
+        hidden_states = tensor_list[context.attn_tp_rank]
+        attn_tp_reduce_scatter(hidden_states, tensor_list)
+        if residual_input_mode == ScatterMode.TP_ATTN_FULL:
+            residual = residual.tensor_split(context.attn_tp_size)[context.attn_tp_rank]
+        if hidden_states.shape[0] != 0:
+            hidden_states, residual = layernorm(hidden_states, residual)
+        return hidden_states, residual
+
+
+class _CommunicateSummableTensorPairFn:
+    @staticmethod
+    def get_fn(
+        hidden_states_input_mode: ScatterMode,
+        residual_input_mode: ScatterMode,
+        output_mode: ScatterMode,
+        context: _Context,
+    ):
+        """It is allowed to make (hidden_states, residual) := (hidden_states + residual, None) if needed."""
         if context.is_same_group_size(
             hidden_states_input_mode, output_mode
         ) and context.is_same_group_size(residual_input_mode, output_mode):
-            return hidden_states, residual
+            return _CommunicateSummableTensorPairFn._trivial
 
         if (
             (hidden_states_input_mode == ScatterMode.FULL)
             and (residual_input_mode == ScatterMode.TP_ATTN_FULL)
             and (output_mode == ScatterMode.TP_ATTN_FULL)
         ):
-            # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
-            # important: forward_batch.gathered_buffer is used both after scatter and after gather.
-            # be careful about this!
-            hidden_states, global_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            dp_scatter(hidden_states, global_hidden_states, forward_batch)
-            return hidden_states, residual
+            return _CommunicateSummableTensorPairFn._scatter_hidden_states
 
         if (
             (hidden_states_input_mode == ScatterMode.SCATTERED)
             and (residual_input_mode == ScatterMode.SCATTERED)
             and (output_mode == ScatterMode.TP_ATTN_FULL)
        ):
-            hidden_states += residual
-            residual = None
-            hidden_states, local_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            attn_tp_all_gather(
-                list(hidden_states.tensor_split(context.attn_tp_size)),
-                local_hidden_states,
-            )
-            return hidden_states, residual
+            return _CommunicateSummableTensorPairFn._gather
 
         raise NotImplementedError(
             f"{hidden_states_input_mode=} {residual_input_mode=} {output_mode=}"
         )
 
-    context.check_shapes(
-        (hidden_states, residual), (hidden_states_input_mode, residual_input_mode)
-    )
-    return context.check_shapes(_inner(), (output_mode, output_mode))
+    @staticmethod
+    def _trivial(
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
+        forward_batch: ForwardBatch,
+        context: _Context,
+    ):
+        return hidden_states, residual
+
+    @staticmethod
+    def _scatter_hidden_states(
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
+        forward_batch: ForwardBatch,
+        context: _Context,
+    ):
+        # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
+        # important: forward_batch.gathered_buffer is used both after scatter and after gather.
+        # be careful about this!
+        hidden_states, global_hidden_states = (
+            forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+            hidden_states,
+        )
+        dp_scatter(hidden_states, global_hidden_states, forward_batch)
+        return hidden_states, residual
+
+    @staticmethod
+    def _gather(
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
+        forward_batch: ForwardBatch,
+        context: _Context,
+    ):
+        hidden_states += residual
+        residual = None
+        hidden_states, local_hidden_states = (
+            forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+            hidden_states,
+        )
+        attn_tp_all_gather(
+            list(hidden_states.tensor_split(context.attn_tp_size)),
+            local_hidden_states,
+        )
+        return hidden_states, residual
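
A note on the functools.partial import added at the top of the file: most dispatch targets share one call signature, but _scatter_hidden_states_and_residual also needs residual_input_mode, which varies per layer. get_fn therefore binds that value with partial at selection time, so the returned callable again matches the common signature. A small illustration of the binding (hypothetical names, not the repository's code):

from functools import partial


def _scatter(hidden_states, *, residual_input_mode):
    # The bound keyword changes behavior, but callers never pass it.
    if residual_input_mode == "tp_attn_full":
        return hidden_states[: len(hidden_states) // 2]
    return hidden_states


def get_fn(residual_input_mode):
    # Bake the per-instance configuration in now; the result has the same
    # signature as every other dispatch target.
    return partial(_scatter, residual_input_mode=residual_input_mode)


fn = get_fn("tp_attn_full")  # chosen once, at construction
out = fn([1, 2, 3, 4])       # hot path: out == [1, 2]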