Unverified Commit 0e25fcc4 authored by Masaki Kozuki's avatar Masaki Kozuki Committed by GitHub

Cherry-pick Megatron-LM's changes in pipeline model parallel for T5 (#1232)

* update parallel_state

* update pipeline common funcs - forward_step and backward_step

* update pipelining w/o interleaving

* type hint

* merge utils into without_interleaving

Motivation: functions in utils are only used by
forward_backward_pipelining_without_interleaving

* fix handling of `model_type`

* fix import of DDP

* update set_input_tensor method

* fix

* cosmetic

* update model

* refactor pipeline test scripts
parent ab7af058
......@@ -28,3 +28,8 @@ class AttnType(enum.Enum):
class AttnMaskType(enum.Enum):
padding = 1
causal = 2
class ModelType(enum.Enum):
encoder_or_decoder = 1
encoder_and_decoder = 2
......@@ -13,12 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model and data parallel groups."""
from typing import Tuple
from typing import Tuple, Optional
import torch
# TODO (mkozuki): Consider splitting up `utils`; this import exists here
# only for `ensure_divisibility`.
from apex.transformer.utils import ensure_divisibility
......@@ -35,6 +33,7 @@ _DATA_PARALLEL_GROUP = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None
# These values enable us to change the mpu sizes on the fly.
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
......@@ -56,14 +55,19 @@ def is_unitialized():
def initialize_model_parallel(
tensor_model_parallel_size_=1, pipeline_model_parallel_size_=1, virtual_pipeline_model_parallel_size_=None
):
tensor_model_parallel_size_: int = 1,
pipeline_model_parallel_size_: int = 1,
virtual_pipeline_model_parallel_size_: Optional[int] = None,
pipeline_model_parallel_split_rank_: Optional[int] = None,
) -> None:
"""
Initialize model and data parallel groups.
Arguments:
tensor_model_parallel_size: number of GPUs used to parallelize the model tensor.
pipeline_model_parallel_size: number of GPUs used to parallelize the model pipeline.
virtual_pipeline_model_parallel_size: number of virtual stages (interleaved pipeline).
pipeline_model_parallel_split_rank: for models with both an encoder and a decoder, the rank in the pipeline at which the encoder/decoder split occurs.
Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
......@@ -106,6 +110,10 @@ def initialize_model_parallel(
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size_
if pipeline_model_parallel_split_rank_ is not None:
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank_
rank = torch.distributed.get_rank()
# Build the data-parallel groups.
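A hedged usage sketch of the new split-rank argument (values illustrative, not part of this commit; assumes `torch.distributed` is already initialized across 8 ranks):

# Hypothetical setup: 8 GPUs, tensor parallel 2, pipeline parallel 4.
# With pipeline_model_parallel_split_rank_=2, pipeline stages 0-1 hold
# the encoder and stages 2-3 hold the decoder of a T5-style model.
from apex.transformer import parallel_state

parallel_state.initialize_model_parallel(
    tensor_model_parallel_size_=2,
    pipeline_model_parallel_size_=4,
    pipeline_model_parallel_split_rank_=2,
)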
......@@ -231,6 +239,44 @@ def is_rank_in_embedding_group(ignore_virtual=False):
return False
def is_pipeline_stage_before_split(rank=None):
"""Return True if pipeline stage executes encoder block for a model
with both encoder and decoder."""
if get_pipeline_model_parallel_world_size() == 1:
return True
if rank is None:
rank = get_pipeline_model_parallel_rank()
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
return True
if rank < _PIPELINE_MODEL_PARALLEL_SPLIT_RANK:
return True
return False
def is_pipeline_stage_after_split(rank=None):
"""Return True if pipeline stage executes decoder block for a model
with both encoder and decoder."""
if get_pipeline_model_parallel_world_size() == 1:
return True
if rank is None:
rank = get_pipeline_model_parallel_rank()
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
return True
if rank >= _PIPELINE_MODEL_PARALLEL_SPLIT_RANK:
return True
return False
def is_pipeline_stage_at_split():
"""Return true if pipeline stage executes decoder block and next
stage executes encoder block for a model with both encoder and
decoder."""
rank = get_pipeline_model_parallel_rank()
return is_pipeline_stage_before_split(rank) and is_pipeline_stage_after_split(rank + 1)
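Continuing the hypothetical 4-stage, split-rank-2 layout from the sketch above, the three predicates partition the pipeline as follows (illustrative; assumes initialized parallel state):

# Encoder stages: ranks 0-1; decoder stages: ranks 2-3 (split rank 2).
assert is_pipeline_stage_before_split(rank=0) and is_pipeline_stage_before_split(rank=1)
assert is_pipeline_stage_after_split(rank=2) and is_pipeline_stage_after_split(rank=3)
# is_pipeline_stage_at_split() is True only on rank 1: it is before the
# split while its successor, rank 2, is after it.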
def set_tensor_model_parallel_world_size(world_size):
"""Set the tensor model parallel size"""
global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
......
......@@ -190,8 +190,8 @@ def recv_forward(
"""Receive tensor from previous rank in pipeline (forward receive)."""
if parallel_state.is_pipeline_first_stage():
return None
if timers is not None:
timers("forward-recv").start()
# if timers is not None:
# timers("forward-recv").start()
input_tensor, _ = _communicate(
tensor_send_next=None,
tensor_send_prev=None,
......@@ -201,8 +201,8 @@ def recv_forward(
override_scatter_gather_tensors_in_pipeline=override_scatter_gather_tensors_in_pipeline,
dtype_=_get_current_dtype(dtype),
)
if timers is not None:
timers("forward-recv").stop()
# if timers is not None:
# timers("forward-recv").stop()
return input_tensor
......@@ -211,12 +211,12 @@ def recv_backward(
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
):
) -> torch.Tensor:
"""Receive tensor from next rank in pipeline (backward receive)."""
if parallel_state.is_pipeline_last_stage():
return None
if timers is not None:
timers("backward-recv").start()
# if timers is not None:
# timers("backward-recv").start()
_, output_tensor_grad = _communicate(
tensor_send_next=None,
tensor_send_prev=None,
......@@ -225,8 +225,8 @@ def recv_backward(
tensor_shape=tensor_shape,
dtype_=_get_current_dtype(dtype),
)
if timers is not None:
timers("backward-recv").stop()
# if timers is not None:
# timers("backward-recv").stop()
return output_tensor_grad
......@@ -241,8 +241,8 @@ def send_forward(
"""Send tensor to next rank in pipeline (forward send)."""
if parallel_state.is_pipeline_last_stage():
return
if timers is not None:
timers("forward-send").start()
# if timers is not None:
# timers("forward-send").start()
_communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
......@@ -252,8 +252,8 @@ def send_forward(
tensor_shape=tensor_shape,
dtype_=_get_current_dtype(dtype),
)
if timers is not None:
timers("forward-send").stop()
# if timers is not None:
# timers("forward-send").stop()
def send_backward(
......@@ -266,8 +266,8 @@ def send_backward(
"""Send tensor to previous rank in pipeline (backward send)."""
if parallel_state.is_pipeline_first_stage():
return
if timers is not None:
timers("backward-send").start()
# if timers is not None:
# timers("backward-send").start()
_communicate(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
......@@ -276,8 +276,8 @@ def send_backward(
tensor_shape=tensor_shape,
dtype_=_get_current_dtype(dtype),
)
if timers is not None:
timers("backward-send").stop()
# if timers is not None:
# timers("backward-send").stop()
def send_forward_recv_backward(
......@@ -286,12 +286,12 @@ def send_forward_recv_backward(
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
) -> None:
) -> Union[None, torch.Tensor]:
"""Batched send and recv with next rank in pipeline."""
if parallel_state.is_pipeline_last_stage():
return None
if timers is not None:
timers("forward-send-backward-recv").start()
# if timers is not None:
# timers("forward-send-backward-recv").start()
_, output_tensor_grad = _communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
......@@ -300,8 +300,8 @@ def send_forward_recv_backward(
tensor_shape=tensor_shape,
dtype_=_get_current_dtype(dtype),
)
if timers is not None:
timers("forward-send-backward-recv").stop()
# if timers is not None:
# timers("forward-send-backward-recv").stop()
return output_tensor_grad
......@@ -311,12 +311,12 @@ def send_backward_recv_forward(
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
) -> torch.Tensor:
) -> Union[None, torch.Tensor]:
"""Batched send and recv with previous rank in pipeline."""
if parallel_state.is_pipeline_first_stage():
return None
if timers is not None:
timers("backward-send-forward-recv").start()
# if timers is not None:
# timers("backward-send-forward-recv").start()
input_tensor, _ = _communicate(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
......@@ -325,8 +325,8 @@ def send_backward_recv_forward(
tensor_shape=tensor_shape,
dtype_=_get_current_dtype(dtype),
)
if timers is not None:
timers("backward-send-forward-recv").stop()
# if timers is not None:
# timers("backward-send-forward-recv").stop()
return input_tensor
......@@ -339,8 +339,8 @@ def send_forward_recv_forward(
timers: _Timers = None,
) -> torch.Tensor:
"""Batched recv from previous rank and send to next rank in pipeline."""
if timers is not None:
timers("forward-send-forward-recv").start()
# if timers is not None:
# timers("forward-send-forward-recv").start()
input_tensor, _ = _communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
......@@ -349,8 +349,8 @@ def send_forward_recv_forward(
tensor_shape=tensor_shape,
dtype_=_get_current_dtype(dtype),
)
if timers is not None:
timers("forward-send-forward-recv").stop()
# if timers is not None:
# timers("forward-send-forward-recv").stop()
return input_tensor
......@@ -363,8 +363,8 @@ def send_backward_recv_backward(
timers: _Timers = None,
) -> torch.Tensor:
"""Batched recv from next rank and send to previous rank in pipeline."""
if timers is not None:
timers("backward-send-backward-recv").start()
# if timers is not None:
# timers("backward-send-backward-recv").start()
_, output_tensor_grad = _communicate(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
......@@ -373,8 +373,8 @@ def send_backward_recv_backward(
tensor_shape=tensor_shape,
dtype_=_get_current_dtype(dtype),
)
if timers is not None:
timers("backward-send-backward-recv").stop()
# if timers is not None:
# timers("backward-send-backward-recv").stop()
return output_tensor_grad
......@@ -387,10 +387,10 @@ def send_forward_backward_recv_forward_backward(
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
):
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Batched send and recv with previous and next ranks in pipeline."""
if timers is not None:
timers("forward-backward-send-forward-backward-recv").start()
# if timers is not None:
# timers("forward-backward-send-forward-backward-recv").start()
input_tensor, output_tensor_grad = _communicate(
tensor_send_next=output_tensor,
tensor_send_prev=input_tensor_grad,
......@@ -399,6 +399,6 @@ def send_forward_backward_recv_forward_backward(
tensor_shape=tensor_shape,
dtype_=_get_current_dtype(dtype),
)
if timers is not None:
timers("forward-backward-send-forward-backward-recv").stop()
# if timers is not None:
# timers("forward-backward-send-forward-backward-recv").stop()
return input_tensor, output_tensor_grad
# NOTE (mkozuki): For simplicity, tentatively `timers` related operations are commented out.
from typing import Any, Callable, Dict, List, Tuple, Union, Optional
# NOTE(mkozuki): For simplicity, tentatively `timers` related operations are commented out.
from typing import Any, Callable, Dict, List, Tuple, Union, Optional, Sequence
import torch
from apex.transformer import parallel_state
from apex.transformer.enums import ModelType
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.utils import listify_model
from apex.transformer.pipeline_parallel.utils import unwrap_model
from apex.transformer.pipeline_parallel.utils import get_model_type
from apex.transformer.tensor_parallel.layers import set_defaults_if_not_set_tensor_model_parallel_attributes
......@@ -19,8 +21,8 @@ def build_model(
model_provider_func: Callable[[Any, Dict[str, Any]], torch.nn.Module],
wrap_with_ddp: bool = True,
virtual_pipeline_model_parallel_size: Optional[int] = None,
*args,
**kwargs
*args: Any,
**kwargs: Any,
) -> List[torch.nn.Module]:
"""Build the model satisfying pipeline model parallel requirements.
......@@ -110,6 +112,7 @@ def _get_params_for_weight_decay_optimization(
model: Union[torch.nn.Module, List[torch.nn.Module]],
) -> Dict[str, torch.nn.Parameter]:
"""Divide params into with-weight-decay and without-weight-decay groups.
Layernorms and biases will have no weight decay but the rest will.
"""
modules = listify_model(model)
......@@ -137,13 +140,12 @@ def forward_step(
forward_step_func: FwdStepFunc,
batch: Batch,
model: torch.nn.Module,
input_tensor: Optional[torch.Tensor],
input_tensor: Optional[Union[torch.Tensor, List[torch.Tensor]]],
losses_reduced: List[torch.Tensor],
):
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
"""Forward step for passed-in model.
If first stage, input tensor is obtained from data_iterator, otherwise
passed-in input_tensor is used.
If first stage, input tensor is obtained from batch, otherwise passed-in input_tensor is used.
Returns output tensor.
......@@ -161,12 +163,16 @@ def forward_step(
# timers = get_timers()
# timers("forward-compute").start()
unwrapped_model = unwrap_model(model)
model_type = get_model_type(unwrapped_model)
# NOTE (mkozuki): The passed `model` is expected to implement `set_input_tensor`.
# See https://github.com/NVIDIA/Megatron-LM/blob/5ac5571ba0265af4c491ee0af1508ca7589450c6/megatron/model/transformer.py#L679 # NOQA
# for the details of `set_input_tensor`.
unwrap_output_tensor = not isinstance(input_tensor, list)
if unwrap_output_tensor:
input_tensor = [input_tensor]
unwrapped_model.set_input_tensor(input_tensor)
output_tensor, loss_func = forward_step_func(batch, model)
# print(f"forward_step| pipeline rank: {parallel_state.get_pipeline_model_parallel_rank()} is_pipeline_last_stage?: {parallel_state.is_pipeline_last_stage()}")
if parallel_state.is_pipeline_last_stage():
output_tensor = loss_func(output_tensor)
loss, loss_reduced = output_tensor
......@@ -174,14 +180,22 @@ def forward_step(
losses_reduced.append(loss_reduced)
# timers("forward-compute").stop()
return output_tensor
# If T5 model (or other model with encoder and decoder)
# and in decoder stack, then send encoder_hidden_state
# downstream as well.
if parallel_state.is_pipeline_stage_after_split() and model_type == ModelType.encoder_and_decoder:
return [output_tensor, input_tensor[-1]]
if unwrap_output_tensor:
return output_tensor
return [output_tensor]
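A hedged sketch of what the decoder-side branch above returns (names and shapes are illustrative, not from the commit):

import torch

# Hypothetical tensors on a decoder-side stage of a T5-style model:
decoder_hidden = torch.randn(8, 2, 16)    # (dec_seq_len, micro_batch, hidden)
encoder_hidden = torch.randn(2, 32, 16)   # (micro_batch, enc_seq_len, hidden)
input_tensor = [decoder_hidden, encoder_hidden]
output_tensor = decoder_hidden * 2.0      # stand-in for the stage's real output
# Mirrors `return [output_tensor, input_tensor[-1]]` above: the encoder
# hidden state keeps flowing downstream to the remaining decoder stages.
stage_output = [output_tensor, input_tensor[-1]]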
def backward_step(
input_tensor: Optional[torch.Tensor],
output_tensor: torch.Tensor,
output_tensor_grad: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
model_type: ModelType,
) -> Union[None, torch.Tensor, Sequence[torch.Tensor]]:
"""Backward step through passed-in output tensor.
If last stage, output_tensor_grad is None, otherwise gradient of loss
......@@ -200,19 +214,39 @@ def backward_step(
# timers = get_timers()
# timers("backward-compute").start()
# Retain the grad on the input_tensor.
unwrap_input_tensor_grad = not isinstance(input_tensor, list)
if unwrap_input_tensor_grad:
input_tensor = [input_tensor]
for x in input_tensor:
if x is not None:
x.retain_grad()
if not isinstance(output_tensor, list):
output_tensor = [output_tensor]
if not isinstance(output_tensor_grad, list):
output_tensor_grad = [output_tensor_grad]
if input_tensor is not None:
input_tensor.retain_grad()
# Backward pass.
# if output_tensor_grad is None:
# output_tensor = optimizer.scale_loss(output_tensor)
torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
input_tensor_grad = None
torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0])
# Collect the grad of the input_tensor.
input_tensor_grad = [None]
if input_tensor is not None:
input_tensor_grad = input_tensor.grad
# timers("backward-compute").stop()
input_tensor_grad = []
for x in input_tensor:
input_tensor_grad.append(None if x is None else x.grad)
# Handle single skip connection if it exists (encoder_hidden_state in model with encoder and decoder).
if (
parallel_state.get_pipeline_model_parallel_world_size() > 1 and
parallel_state.is_pipeline_stage_after_split() and
model_type == ModelType.encoder_and_decoder
):
if output_tensor_grad[1] is not None:
# TODO (mkozuki): Replace the in-place add with `+= output_tensor_grad[1]`?
input_tensor_grad[-1].add_(output_tensor_grad[1])
return input_tensor_grad
# timers("backward-compute").stop()
return input_tensor_grad[0] if unwrap_input_tensor_grad else input_tensor_grad
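A minimal, self-contained sketch of the gradient accumulation above: the encoder hidden state receives gradient both from this stage's own backward pass and from the skip-connection grad sent back by the next stage (all names illustrative):

import torch

encoder_hidden = torch.randn(2, 32, 16, requires_grad=True)  # illustrative
(encoder_hidden * 2).sum().backward()        # local backward fills .grad
skip_grad = torch.ones_like(encoder_hidden)  # grad arriving via the skip path
# Mirrors `input_tensor_grad[-1].add_(output_tensor_grad[1])` above:
encoder_hidden.grad.add_(skip_grad)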
......@@ -6,7 +6,9 @@ import torch
from apex.transformer.pipeline_parallel.utils import listify_model
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.utils import get_kth_microbatch
from apex.transformer.pipeline_parallel.schedules.common import Batch, FwdStepFunc
from apex.transformer.pipeline_parallel.utils import get_model_type
from apex.transformer.pipeline_parallel.schedules.common import Batch
from apex.transformer.pipeline_parallel.schedules.common import FwdStepFunc
from apex.transformer.pipeline_parallel.schedules.common import forward_step
from apex.transformer.pipeline_parallel.schedules.common import backward_step
from apex.transformer.log_util import get_transformer_logger
......@@ -58,6 +60,7 @@ def forward_backward_no_pipelining(
msg = f"`model` is expected be a `nn.Module`, but {type(model)}"
raise RuntimeError(msg)
model = model[0]
model_type = get_model_type(model)
context_handler = placeholder_handler
if isinstance(model, torch.nn.parallel.distributed.DistributedDataParallel):
......@@ -75,7 +78,7 @@ def forward_backward_no_pipelining(
forward_step_func, cur_micro_batch, model, input_tensor, losses_reduced)
if not forward_only:
_logger.debug("Call `backward_step`")
backward_step(input_tensor, output_tensor, output_tensor_grad)
backward_step(input_tensor, output_tensor, output_tensor_grad, model_type=model_type)
# Run computation for the last microbatch outside of the context handler
# (we want to synchronize gradients).
......@@ -86,6 +89,6 @@ def forward_backward_no_pipelining(
)
if not forward_only:
_logger.debug("Call `backward_step`")
backward_step(input_tensor, output_tensor, output_tensor_grad)
backward_step(input_tensor, output_tensor, output_tensor_grad, model_type=model_type)
return losses_reduced
from typing import List, Union, Optional
from typing import List, Union, Optional, Sequence
import torch
from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel import p2p_communication
from apex.transformer.pipeline_parallel.schedules.common import Batch, FwdStepFunc
from apex.transformer.pipeline_parallel.schedules.common import Batch
from apex.transformer.pipeline_parallel.schedules.common import FwdStepFunc
from apex.transformer.pipeline_parallel.schedules.common import backward_step
from apex.transformer.pipeline_parallel.schedules.common import forward_step
from apex.transformer.pipeline_parallel.utils import get_kth_microbatch
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.utils import get_model_type
from apex.transformer.log_util import get_transformer_logger
......@@ -18,7 +20,7 @@ __all__ = ["_forward_backward_pipelining_with_interleaving"]
_logger = get_transformer_logger(__name__)
# TODO (mkozuki): Reduce cyclomatic complexity
# TODO(mkozuki): Reduce cyclomatic complexity
def _forward_backward_pipelining_with_interleaving(
forward_step_func: FwdStepFunc,
batch: List[Batch],
......@@ -26,7 +28,7 @@ def _forward_backward_pipelining_with_interleaving(
*,
forward_only: bool,
tensor_shape: Optional[Union[List[int], torch.Size]] = None,
):
) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
"""Run interleaved 1F1B schedule with communication between pipeline stages as needed.
This function assumes `batch` and `model` are a list of `Batch`es and a list of `torch.nn.Module`s, respectively.
......@@ -56,22 +58,22 @@ def _forward_backward_pipelining_with_interleaving(
if not isinstance(model, list):
raise RuntimeError("`model` must be a list of `nn.Module`'s'")
num_model_chunks = len(model)
input_tensors = [[] for _ in range(num_model_chunks)]
output_tensors = [[] for _ in range(num_model_chunks)]
curr_iters = [0 for _ in range(num_model_chunks)]
losses_reduced = []
num_model_chunks: int = len(model)
input_tensors: List[List[Union[None, torch.Tensor]]] = [[] for _ in range(num_model_chunks)]
output_tensors: List[List[Union[None, torch.Tensor]]] = [[] for _ in range(num_model_chunks)]
curr_iters: List[int] = [0 for _ in range(num_model_chunks)]
losses_reduced: List[Union[None, torch.Tensor]] = []
if not forward_only:
output_tensor_grads = [[] for _ in range(num_model_chunks)]
output_tensor_grads: List[List[Union[None, torch.Tensor]]] = [[] for _ in range(num_model_chunks)]
pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
pipeline_parallel_rank = parallel_state.get_pipeline_model_parallel_rank()
pipeline_parallel_size: int = parallel_state.get_pipeline_model_parallel_world_size()
pipeline_parallel_rank: int = parallel_state.get_pipeline_model_parallel_rank()
# Compute number of warmup and remaining microbatches.
num_microbatches = get_num_microbatches() * num_model_chunks
all_warmup_microbatches = False
num_microbatches: int = get_num_microbatches() * num_model_chunks
all_warmup_microbatches: bool = False
if forward_only:
num_warmup_microbatches = num_microbatches
num_warmup_microbatches: int = num_microbatches
else:
# Run all forward passes and then all backward passes if number of
# microbatches is just the number of pipeline stages.
......@@ -86,7 +88,7 @@ def _forward_backward_pipelining_with_interleaving(
num_warmup_microbatches = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size
num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
num_microbatches_remaining = num_microbatches - num_warmup_microbatches
num_microbatches_remaining: int = num_microbatches - num_warmup_microbatches
_logger.info(
f"num_microbatches: {num_microbatches}, "
......@@ -106,10 +108,11 @@ def _forward_backward_pipelining_with_interleaving(
model_chunk_id = num_model_chunks - model_chunk_id - 1
return model_chunk_id
def forward_step_helper(microbatch_id, curr_iters):
def forward_step_helper(microbatch_id: int, curr_iters: List[int]) -> torch.Tensor:
"""Helper method to run forward step with model split into chunks
(run set_virtual_pipeline_model_parallel_rank() before calling
forward_step())."""
(run set_virtual_pipeline_model_parallel_rank() before calling forward_step()).
"""
model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
......@@ -137,11 +140,13 @@ def _forward_backward_pipelining_with_interleaving(
return output_tensor
def backward_step_helper(microbatch_id):
def backward_step_helper(microbatch_id: int) -> torch.Tensor:
"""Helper method to run backward step with model split into chunks
(run set_virtual_pipeline_model_parallel_rank() before calling
backward_step())."""
(run set_virtual_pipeline_model_parallel_rank() before calling backward_step()).
"""
model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
model_type = get_model_type(model[model_chunk_id])
parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
if parallel_state.is_pipeline_last_stage():
......@@ -150,7 +155,7 @@ def _forward_backward_pipelining_with_interleaving(
input_tensor = input_tensors[model_chunk_id].pop(0)
output_tensor = output_tensors[model_chunk_id].pop(0)
output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
input_tensor_grad = backward_step(input_tensor, output_tensor, output_tensor_grad)
input_tensor_grad = backward_step(input_tensor, output_tensor, output_tensor_grad, model_type=model_type)
return input_tensor_grad
......@@ -200,7 +205,8 @@ def _forward_backward_pipelining_with_interleaving(
output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
else:
_logger.debug("send fwd and receive fwd")
input_tensor = p2p_communication.send_forward_recv_forward(output_tensor, recv_prev=recv_prev, tensor_shape=tensor_shape)
input_tensor = p2p_communication.send_forward_recv_forward(
output_tensor, recv_prev=recv_prev, tensor_shape=tensor_shape)
input_tensors[next_forward_model_chunk_id].append(input_tensor)
###################################################################################################################
......@@ -302,7 +308,8 @@ def _forward_backward_pipelining_with_interleaving(
if k == (num_microbatches - 1):
recv_next = False
output_tensor_grads[next_backward_model_chunk_id].append(
p2p_communication.send_backward_recv_backward(input_tensor_grad, recv_next=recv_next, tensor_shape=tensor_shape)
p2p_communication.send_backward_recv_backward(
input_tensor_grad, recv_next=recv_next, tensor_shape=tensor_shape)
)
return losses_reduced
from typing import Union, List, Optional
from typing import Union, List, Optional, Sequence
import torch
from apex.transformer import parallel_state
from apex.transformer.enums import ModelType
from apex.transformer.pipeline_parallel import p2p_communication
from apex.transformer.pipeline_parallel.utils import get_kth_microbatch
from apex.transformer.pipeline_parallel.utils import listify_model
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.schedules.common import Batch, FwdStepFunc
from apex.transformer.pipeline_parallel.utils import get_model_type
from apex.transformer.pipeline_parallel.schedules.common import Batch
from apex.transformer.pipeline_parallel.schedules.common import FwdStepFunc
from apex.transformer.pipeline_parallel.schedules.common import forward_step
from apex.transformer.pipeline_parallel.schedules.common import backward_step
from apex.transformer.log_util import get_transformer_logger
......@@ -19,14 +22,126 @@ __all__ = ["forward_backward_pipelining_without_interleaving"]
_logger = get_transformer_logger(__name__)
def get_tensor_shapes(
rank: int,
model_type: ModelType,
*,
tensor_shape: Union[List[int], torch.Size],
decoder_sequence_length: Optional[int] = None,
) -> Sequence[Sequence[int]]:
# Determine the right tensor sizes (based on the position of the rank with
# respect to the split rank) and model size.
# Send two tensors if model is T5 and rank is in decoder stage:
# first tensor is decoder (pre-transpose),
# second tensor is encoder (post-transpose).
# If model is T5 and rank is at the boundary:
# send one tensor (post-transpose from encoder).
# Otherwise, send one tensor (pre-transpose).
assert (
len(tensor_shape) == 3
), f"`tensor_shape` should be [sequence_length, micro_batch_size, hidden_size] but {tensor_shape}"
sequence_length, micro_batch_size, hidden_size = tensor_shape
tensor_shapes = []
if model_type == ModelType.encoder_and_decoder:
if decoder_sequence_length is None:
raise ValueError("`decoder_sequence_length` is required for `ModelType.encoder_and_decoder`")
if parallel_state.is_pipeline_stage_before_split(rank):
# If next rank is after split, then need transpose for encoder_hidden_state.
if parallel_state.is_pipeline_stage_before_split(rank + 1):
tensor_shapes.append((sequence_length, micro_batch_size, hidden_size))
else:
tensor_shapes.append((micro_batch_size, sequence_length, hidden_size))
else:
tensor_shapes.append((decoder_sequence_length, micro_batch_size, hidden_size))
tensor_shapes.append((micro_batch_size, sequence_length, hidden_size))
else:
tensor_shapes.append((sequence_length, micro_batch_size, hidden_size))
return tensor_shapes
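An illustrative walk-through of the shapes this yields for a T5-style model (hypothetical numbers; assumes initialized parallel state with pipeline size 4 and split rank 2):

# With tensor_shape=[32, 2, 16] (enc_seq, micro_batch, hidden) and decoder_sequence_length=8:
#   get_tensor_shapes(0, ModelType.encoder_and_decoder, ...) -> [(32, 2, 16)]
#   get_tensor_shapes(1, ...) -> [(2, 32, 16)]               # boundary, post-transpose
#   get_tensor_shapes(2, ...) -> [(8, 2, 16), (2, 32, 16)]   # decoder output + encoder state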
def recv_forward(tensor_shapes: List[Union[None, List[int]]]) -> List[Union[None, torch.Tensor]]:
input_tensors = []
for tensor_shape in tensor_shapes:
if tensor_shape is None:
input_tensors.append(None)
else:
input_tensors.append(p2p_communication.recv_forward(tensor_shape=tensor_shape))
return input_tensors
def recv_backward(tensor_shapes: List[Union[None, List[int]]]) -> List[Union[None, torch.Tensor]]:
output_tensor_grads = []
for tensor_shape in tensor_shapes:
if tensor_shape is None:
output_tensor_grads.append(None)
else:
output_tensor_grads.append(p2p_communication.recv_backward(tensor_shape=tensor_shape))
return output_tensor_grads
def send_forward(
output_tensors: Union[torch.Tensor, List[Union[None, torch.Tensor]]], tensor_shapes: List[Union[None, List[int]]],
) -> None:
if not isinstance(output_tensors, list):
output_tensors = [output_tensors]
for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes):
if tensor_shape is None:
continue
p2p_communication.send_forward(output_tensor, tensor_shape=tensor_shape)
def send_backward(
input_tensor_grads: Union[torch.Tensor, List[Union[None, torch.Tensor]]],
tensor_shapes: List[Union[None, List[int]]],
) -> None:
if not isinstance(input_tensor_grads, list):
input_tensor_grads = [input_tensor_grads]
for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes):
if tensor_shape is None:
continue
p2p_communication.send_backward(input_tensor_grad, tensor_shape=tensor_shape)
def send_forward_recv_backward(
output_tensors: Union[torch.Tensor, List[Union[None, torch.Tensor]]], tensor_shapes: List[Union[None, List[int]]],
) -> List[Union[None, torch.Tensor]]:
if not isinstance(output_tensors, list):
output_tensors = [output_tensors]
output_tensor_grads = []
for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes):
if tensor_shape is None:
output_tensor_grads.append(None)
continue
output_tensor_grad = p2p_communication.send_forward_recv_backward(output_tensor, tensor_shape=tensor_shape)
output_tensor_grads.append(output_tensor_grad)
return output_tensor_grads
def send_backward_recv_forward(
input_tensor_grads: Union[torch.Tensor, List[Union[None, torch.Tensor]]],
tensor_shapes: List[Union[None, List[int]]],
) -> List[Union[None, torch.Tensor]]:
if not isinstance(input_tensor_grads, list):
input_tensor_grads = [input_tensor_grads]
input_tensors = []
for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes):
if tensor_shape is None:
input_tensors.append(None)
continue
input_tensor = p2p_communication.send_backward_recv_forward(input_tensor_grad, tensor_shape=tensor_shape)
input_tensors.append(input_tensor)
return input_tensors
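These wrappers simply map the single-tensor p2p calls over a per-stage shape list, leaving `None` slots untouched; a hedged usage sketch (requires an initialized pipeline group):

# Hypothetical call on a decoder-side stage, which receives two tensors:
recv_shapes = [(8, 2, 16), (2, 32, 16)]
input_tensors = recv_forward(tensor_shapes=recv_shapes)  # -> [Tensor, Tensor]
# A stage whose shape entry is None gets None in the corresponding slot.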
def forward_backward_pipelining_without_interleaving(
forward_step_func: FwdStepFunc,
batch: Batch,
model: Union[torch.nn.Module, List[torch.nn.Module]],
*,
forward_only: bool,
tensor_shape: Optional[Union[List[int], torch.Size]] = None,
):
forward_step_func: FwdStepFunc,
batch: Batch,
model: Union[torch.nn.Module, List[torch.nn.Module]],
*,
forward_only: bool,
tensor_shape: Optional[Union[List[int], torch.Size]] = None,
decoder_sequence_length: Optional[int] = None,
) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
"""Run non-interleaved 1F1B schedule, with communication between pipeline stages.
This pipeline parallel scheduling consists of three steps:
......@@ -51,21 +166,28 @@ def forward_backward_pipelining_without_interleaving(
"""
# timers = get_timers()
model = listify_model(model)
model: List[torch.nn.Module] = listify_model(model)
if len(model) != 1:
msg = f"`model` is expected be a `nn.Module`, but {type(model)}"
raise RuntimeError(msg)
model = model[0]
model: torch.nn.Module = model[0]
# Compute number of warmup microbatches.
num_microbatches = get_num_microbatches()
num_warmup_microbatches = (
parallel_state.get_pipeline_model_parallel_world_size()
- parallel_state.get_pipeline_model_parallel_rank()
- 1
num_microbatches: int = get_num_microbatches()
num_warmup_microbatches: int = (
parallel_state.get_pipeline_model_parallel_world_size() - parallel_state.get_pipeline_model_parallel_rank() - 1
)
num_warmup_microbatches: int = min(num_warmup_microbatches, num_microbatches)
num_microbatches_remaining: int = num_microbatches - num_warmup_microbatches
model_type = get_model_type(model)
rank: int = parallel_state.get_pipeline_model_parallel_rank()
recv_tensor_shapes: List[List[int]] = get_tensor_shapes(
rank - 1, model_type, tensor_shape=tensor_shape, decoder_sequence_length=decoder_sequence_length
)
send_tensor_shapes: List[List[int]] = get_tensor_shapes(
rank, model_type, tensor_shape=tensor_shape, decoder_sequence_length=decoder_sequence_length
)
num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
num_microbatches_remaining = num_microbatches - num_warmup_microbatches
_logger.info(
f"num_microbatches: {num_microbatches}, "
......@@ -74,13 +196,9 @@ def forward_backward_pipelining_without_interleaving(
)
# Input, output tensors only need to be saved when doing backward passes
input_tensors = None
output_tensors = None
if not forward_only:
input_tensors = []
output_tensors = []
losses_reduced = []
input_tensors: List[Union[None, torch.Tensor]] = []
output_tensors: List[Union[None, torch.Tensor]] = []
losses_reduced: List[Union[None, torch.Tensor]] = []
###################################################################################################################
# Run warmup forward passes.
###################################################################################################################
......@@ -88,11 +206,11 @@ def forward_backward_pipelining_without_interleaving(
for i in range(num_warmup_microbatches):
_logger.debug(f"warmup iter: {i} / {num_warmup_microbatches}")
_logger.debug("receive fwd")
input_tensor = p2p_communication.recv_forward(tensor_shape=tensor_shape)
input_tensor = recv_forward(tensor_shapes=recv_tensor_shapes)
cur_microbatch = get_kth_microbatch(batch, i)
output_tensor = forward_step(forward_step_func, cur_microbatch, model, input_tensor, losses_reduced)
_logger.debug("send fwd")
p2p_communication.send_forward(output_tensor, tensor_shape=tensor_shape)
send_forward(output_tensor, tensor_shapes=send_tensor_shapes)
if not forward_only:
input_tensors.append(input_tensor)
......@@ -103,7 +221,7 @@ def forward_backward_pipelining_without_interleaving(
# receive this tensor here.
if num_microbatches_remaining > 0:
_logger.debug("recv_forward before steady state start")
input_tensor = p2p_communication.recv_forward(tensor_shape=tensor_shape)
input_tensor: List[Union[None, torch.Tensor]] = recv_forward(tensor_shapes=recv_tensor_shapes)
###################################################################################################################
# Run 1F1B in steady state.
......@@ -111,21 +229,23 @@ def forward_backward_pipelining_without_interleaving(
_logger.info("Steady phase")
for i in range(num_microbatches_remaining):
_logger.debug(f"steady iter: {i} / {num_microbatches_remaining}")
last_iteration = i == (num_microbatches_remaining - 1)
last_iteration: bool = i == (num_microbatches_remaining - 1)
cur_microbatch = get_kth_microbatch(batch, i + num_warmup_microbatches)
output_tensor = forward_step(forward_step_func, cur_microbatch, model, input_tensor, losses_reduced)
cur_microbatch: torch.Tensor = get_kth_microbatch(batch, i + num_warmup_microbatches)
output_tensor: Union[torch.Tensor, Sequence[torch.Tensor]] = forward_step(
forward_step_func, cur_microbatch, model, input_tensor, losses_reduced
)
if forward_only:
_logger.debug("send fwd")
p2p_communication.send_forward(output_tensor, tensor_shape=tensor_shape)
send_forward(output_tensor, tensor_shapes=send_tensor_shapes)
if not last_iteration:
_logger.debug("receive fwd (last iteration)")
input_tensor = p2p_communication.recv_forward(tensor_shape=tensor_shape)
input_tensor = recv_forward(tensor_shapes=recv_tensor_shapes)
else:
_logger.debug("send fwd & receive bwd")
output_tensor_grad = p2p_communication.send_forward_recv_backward(output_tensor, tensor_shape=tensor_shape)
output_tensor_grad = send_forward_recv_backward(output_tensor, tensor_shapes=send_tensor_shapes)
# Add input_tensor and output_tensor to end of list.
input_tensors.append(input_tensor)
......@@ -135,18 +255,15 @@ def forward_backward_pipelining_without_interleaving(
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
input_tensor_grad = backward_step(
input_tensor, output_tensor, output_tensor_grad
)
input_tensor_grad = backward_step(input_tensor, output_tensor, output_tensor_grad, model_type=model_type)
if last_iteration:
input_tensor = None
_logger.debug("send bwd")
p2p_communication.send_backward(input_tensor_grad, tensor_shape=tensor_shape)
send_backward(input_tensor_grad, tensor_shapes=recv_tensor_shapes)
else:
_logger.debug("send bwd and receive fwd")
input_tensor = p2p_communication.send_backward_recv_forward(
input_tensor_grad, tensor_shape=tensor_shape)
input_tensor = send_backward_recv_forward(input_tensor_grad, tensor_shapes=recv_tensor_shapes)
###################################################################################################################
# Run cooldown backward passes.
###################################################################################################################
......@@ -158,13 +275,11 @@ def forward_backward_pipelining_without_interleaving(
output_tensor = output_tensors.pop(0)
_logger.debug("receive bwd")
output_tensor_grad = p2p_communication.recv_backward(tensor_shape=tensor_shape)
output_tensor_grad = recv_backward(tensor_shapes=send_tensor_shapes)
input_tensor_grad = backward_step(
input_tensor, output_tensor, output_tensor_grad
)
input_tensor_grad = backward_step(input_tensor, output_tensor, output_tensor_grad, model_type=model_type)
_logger.debug("send bwd")
p2p_communication.send_backward(input_tensor_grad, tensor_shape=tensor_shape)
send_backward(input_tensor_grad, tensor_shapes=recv_tensor_shapes)
return losses_reduced
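A hedged end-to-end sketch of calling the updated schedule for an encoder-and-decoder model (all names illustrative; assumes distributed state and the microbatch calculator are already set up):

losses = forward_backward_pipelining_without_interleaving(
    fwd_step_func,                  # user-supplied forward step function
    batch,                          # microbatched input
    model,                          # module whose model_type is ModelType.encoder_and_decoder
    forward_only=False,
    tensor_shape=[32, 2, 16],       # (enc_seq_len, micro_batch, hidden)
    decoder_sequence_length=8,      # required for encoder_and_decoder models
)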
......@@ -21,6 +21,7 @@ from torch.nn.parallel import DistributedDataParallel
from apex.multi_tensor_apply import multi_tensor_applier
from apex.transformer import parallel_state
from apex.transformer.enums import ModelType
from apex.transformer.microbatches import build_num_microbatches_calculator
from apex.transformer.pipeline_parallel._timers import _Timers
if multi_tensor_applier.available:
......@@ -186,6 +187,19 @@ def unwrap_model(model, module_instances=(DistributedDataParallel,)):
return unwrapped_model
def get_model_type(
model: torch.nn.Module,
) -> ModelType:
"""Get `model_type` of `model`.
If ``model`` doesn't have ``model_type`` attribute, return ``ModelType.encoder_or_decoder``.
Args:
model
"""
return getattr(unwrap_model(model), "model_type", ModelType.encoder_or_decoder)
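A minimal sketch of the fallback behavior (class name hypothetical):

import torch
from apex.transformer.enums import ModelType
from apex.transformer.pipeline_parallel.utils import get_model_type

class ToyT5(torch.nn.Module):  # hypothetical model exposing `model_type`
    def __init__(self):
        super().__init__()
        self.model_type = ModelType.encoder_and_decoder

assert get_model_type(ToyT5()) is ModelType.encoder_and_decoder
assert get_model_type(torch.nn.Linear(2, 2)) is ModelType.encoder_or_decoder  # fallback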
def calc_params_l2_norm(model: torch.nn.Module, bf16: bool):
"""Calculate l2 norm of parameters """
# args = get_args()
......
......@@ -21,6 +21,7 @@ import torch
import torch.nn as nn
from apex import transformer
from apex.transformer.pipeline_parallel.utils import average_losses_across_data_parallel_group
from apex.transformer.testing import global_vars
......@@ -49,7 +50,9 @@ class MyModel(nn.Module):
self.input_tensor = None
def set_input_tensor(self, input_tensor: Union[torch.Tensor, List[torch.Tensor]]) -> None:
self.input_tensor = input_tensor
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
self.input_tensor = input_tensor[0]
def forward(self, x: Optional[torch.Tensor]) -> torch.Tensor:
if self.input_tensor is None:
......@@ -61,6 +64,27 @@ def model_provider_func(hidden_size, pre_process, post_process) -> MyModel:
return MyModel(hidden_size, pre_process, post_process)
def process_batch(batch):
if isinstance(batch, list):
x = batch[0]
else:
x = batch
return x
def fwd_step_func(batch, model):
x = process_batch(batch)
y = model(x)
# NOTE (mkozuki): This loss function is simplistic, but it is enough for now
# to sanity-check the ported pipeline functions.
def loss_func(x):
loss = torch.sum(x)
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {'avg': averaged_loss}
return y, loss_func
class IdentityLayer(torch.nn.Module):
def __init__(self, size, scale=1.0):
super(IdentityLayer, self).__init__()
......
......@@ -11,7 +11,6 @@ from apex.transformer.pipeline_parallel.schedules.common import build_model
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import (
_forward_backward_pipelining_with_interleaving,
)
from apex.transformer.pipeline_parallel.utils import average_losses_across_data_parallel_group
from apex.transformer.pipeline_parallel.utils import setup_microbatch_calculator
from apex.transformer.pipeline_parallel.utils import _reconfigure_microbatch_calculator
from apex.transformer.pipeline_parallel.utils import update_num_microbatches
......@@ -19,6 +18,7 @@ from apex.transformer.testing import global_vars
from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE
from apex.transformer.testing.commons import initialize_distributed
from apex.transformer.testing.commons import print_separator
from apex.transformer.testing.commons import fwd_step_func
from apex.transformer.log_util import get_transformer_logger, set_logging_level
from apex.transformer.testing.commons import model_provider_func
from apex.transformer._data import MegatronPretrainingRandomSampler
......@@ -44,29 +44,7 @@ HIDDEN_SIZE = 16
def Dataset(num_samples: int) -> List[Tuple[torch.Tensor, torch.Tensor]]:
return [(torch.randn(HIDDEN_SIZE), torch.randn(HIDDEN_SIZE // 2)) for _ in range(num_samples)]
def process_batch(batch):
if isinstance(batch, (list, tuple)):
x = batch[0]
else:
x = batch
return x
def fwd_step_func(micro_batch, model):
x = process_batch(micro_batch)
y = model(x)
# NOTE (mkozuki): This loss function is simplistic, but it is enough for now
# to sanity-check the ported pipeline functions.
def loss_func(x):
loss = torch.sum(x)
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {"avg": averaged_loss}
return y, loss_func
return [(torch.randn(HIDDEN_SIZE, HIDDEN_SIZE), torch.randn(HIDDEN_SIZE // 2, HIDDEN_SIZE // 2)) for _ in range(num_samples)]
# Run forward & backward with dynamic batch size.
......@@ -122,7 +100,7 @@ def run_interleaved_with_dynamic_batch_size(
assert isinstance(batch, (list, tuple))
return [get_num_samples(b) for b in batch]
tensor_shape = [micro_batch_size, HIDDEN_SIZE]
tensor_shape = [micro_batch_size, HIDDEN_SIZE, HIDDEN_SIZE]
consumed_samples = 0
for i in range(NUM_ITERATIONS):
update_num_microbatches(consumed_samples, consistency_check=False)
......
from typing import Optional, Union, List
from typing import Optional
import torch
import torch.nn as nn
import apex
from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel import get_forward_backward_func
from apex.transformer.pipeline_parallel.schedules.common import _get_params_for_weight_decay_optimization
......@@ -11,13 +9,14 @@ from apex.transformer.pipeline_parallel.schedules.common import build_model
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import forward_backward_no_pipelining
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import _forward_backward_pipelining_with_interleaving
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import forward_backward_pipelining_without_interleaving
from apex.transformer.pipeline_parallel.utils import average_losses_across_data_parallel_group
from apex.transformer.pipeline_parallel.utils import setup_microbatch_calculator
from apex.transformer.pipeline_parallel.utils import update_num_microbatches
from apex.transformer.testing import global_vars
from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE
from apex.transformer.testing.commons import initialize_distributed
from apex.transformer.testing.commons import print_separator
from apex.transformer.testing.commons import model_provider_func
from apex.transformer.testing.commons import fwd_step_func
from apex.transformer.log_util import get_transformer_logger, set_logging_level
......@@ -35,62 +34,6 @@ fwd_bwd_functions = {
}
# NOTE (mkozuki): `pre_process` and `post_process` are placeholders until the interleaving schedule test arrives.
class MyLayer(nn.Module):
def __init__(self, pre_process: bool, post_process: bool):
super().__init__()
self.pre_process = pre_process
self.post_process = post_process
self.layer = nn.Linear(hidden_size, hidden_size)
def forward(self, x):
return self.layer(x)
class MyModel(nn.Module):
def __init__(self, pre_process: bool = False, post_process: bool = False) -> None:
super().__init__()
self.pre_process = pre_process
self.post_process = post_process
self.layer = MyLayer(pre_process=pre_process, post_process=post_process)
self.input_tensor = None
def set_input_tensor(self, input_tensor: Union[torch.Tensor, List[torch.Tensor]]) -> None:
self.input_tensor = input_tensor
def forward(self, x: Optional[torch.Tensor]) -> torch.Tensor:
if self.input_tensor is None:
return self.layer(x)
return self.layer(self.input_tensor)
def model_provider_func(pre_process, post_process) -> MyModel:
return MyModel(pre_process, post_process)
def process_batch(batch):
if isinstance(batch, list):
x = batch[0]
else:
x = batch
return x
def fwd_step_func(batch, model):
x = process_batch(batch)
y = model(x)
# NOTE (mkozuki): This loss function is simplistic, but it is enough for now
# to sanity-check the ported pipeline functions.
def loss_func(x):
loss = torch.sum(x)
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {'avg': averaged_loss}
return y, loss_func
# TODO (mkozuki): Add a case with `autocast` and `GradScaler`.
# Run forward & backward for one minibatch.
def forward_backward_func_template(
......@@ -121,13 +64,14 @@ def forward_backward_func_template(
model_provider_func,
wrap_with_ddp=True,
virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
hidden_size=hidden_size,
)
assert isinstance(model, list)
assert len(model) == (1 if virtual_pipeline_model_parallel_size is None else virtual_pipeline_model_parallel_size)
_param_groups = _get_params_for_weight_decay_optimization(model)
torch.optim.Adam(_param_groups, lr=1e-4)
tensor_shape = [batch_size // parallel_state.get_data_parallel_world_size(), hidden_size]
tensor_shape = [batch_size // parallel_state.get_data_parallel_world_size(), hidden_size, hidden_size]
batch = (torch.randn(tensor_shape).cuda(),)
tensor_shape[0] = micro_batch_size
......@@ -183,6 +127,7 @@ if __name__ == "__main__":
f"virtual pipeline rank: {parallel_state.get_virtual_pipeline_model_parallel_rank()}\n"
f"{str(e)}"
)
print(failures[-1])
finally:
parallel_state.destroy_model_parallel()
else:
......