Unverified Commit 63d5dd63 authored by Masaki Kozuki's avatar Masaki Kozuki Committed by GitHub
Browse files

Pipeline Model Parallel (#1202)

* Init apex.ppu (pipeline model parallel utility)

Reference commit:

```
commit 5ab646376d67831601d5552c193241d017f1b35c (HEAD -> main, internal/main)
Merge: 14f2c684 7b293d9b
Author: Mohammad Shoeybi <mshoeybi@nvidia.com>
Date:   Wed Sep 22 22:57:54 2021 -0700

    Merge branch 'add_BOS' into 'main'

    Add Beginning of Sentence token option and adding semaphore while multi-threading to prevent crashes and hangs due to connection keep-alives

    See merge request ADLR/megatron-lm!328
```

* removing get_args and replace import - phase 1

* removing get_args and replace import - phase 2

* move ppu to apex.transformer.pipeline_parallel

* update two __init__.py

* update READMEs

* mpu -> parallel_state & tensor_parallel

* fix

* remove not pipeline files

* separate schedules.py - phase 1

* dissect schedules.py

* data_iterators -> batch

* remove optimizer from forward_backward_step funcs

* init test

* Apply 2 suggestion(s...
parent 3303b3e7
# May help avoid undefined symbol errors https://pytorch.org/cppdocs/notes/faq.html#undefined-symbol-errors-from-pytorch-aten
import torch
import warnings
if torch.distributed.is_available():
from . import parallel
......
from typing import Optional
import torch
def _get_current_dtype(dtype: Optional[torch.dtype] = None) -> torch.dtype:
if not torch.is_autocast_enabled():
return torch.float or dtype
else:
return torch.get_autocast_gpu_dtype()
def _cast_if_autocast_enabled(*args):
if not torch.is_autocast_enabled():
return args
......
......@@ -2,4 +2,80 @@
`apex.transformer` is a module which enables efficient large Transformer models at scale.
`apex.transformer.tensor_parallel` is based on [NVIDIA/Megatron-LM](https://github.com/NVIDIA/Megatron-LM)'s `megatron.mpu` module.
`apex.transformer.tensor_parallel` and `apex.transformer.pipeline_parallel` are both based on [NVIDIA/Megatron-LM](https://github.com/NVIDIA/Megatron-LM)'s module.
The former is based on `megatron.mpu` and the latter is on `megatron.schedules` and `megatron.p2p_communication`.
## Tensor Model Parallel (TP)
APEX's tensor model parallel utilities provides some `torch.nn.Module`'s, custom fused kernels, and PRNG state handling.
See Appendix B.2 of [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) for the details of
PRNG state handling.
## Pipeline Model Parallel (PP)
APEX's pipeline model parallel functions require models to have `.set_input_tensor` because
the input tensor for `.forward` method can be `None`.
The following is a really casual sketch of training script with apex pp.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel import get_forward_backward_func
class Model(nn.Module):
...
def __init__(self, *args, **kwargs):
super().__init__()
pre_process = kwargs.pop("pre_process")
post_process = kwargs.pop("post_process")
def set_input_tensor(self, tensor):
self.input_tensor = tensor
def forward(self, x, ...):
if parallel_state.is_pipeline_first_stage():
input = x
else:
input = self.input_tensor
...
def model_provider_func(*args, **kwargs):
return Model(*args, **kwargs)
def loss_func(pred, label):
loss = ...
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {'nice_loss': averaged_loss}
def forward_step_func(batch, model):
input, label = process_batch(batch)
out = model(input)
return out, partial(loss_func, label)
forward_backward_func = get_forward_backward_func(virtual_pipeline_model_parallel_size, pipeline_model_parallel_size)
parallel_state.initialize_model_parallel(
tensor_model_parallel_size,
pipeline_model_parallel_size,
virtual_pipeline_model_parallel_size,
)
# The following line basically is equivalent to `build_model(Model, wrap_with_ddp, virtual_pipeline_model_parallel_size, *model_args, **model_kwargs)`
model = build_model(model_provider_func, wrap_with_ddp, virtual_pipeline_model_parallel_size, *model_args, **model_kwargs)
optimizer = ...
data_loader = ...
for epoch in range(num_epochs):
for batch in data_loader:
forward_backward_func(forward_step_func, batch, model, forward_only=False, tensor_shape)
optimizer.step()
```
from . import tensor_parallel
from . import functional
from .enums import LayerType
from .enums import AttnType
from .enums import AttnMaskType
from .parallel_state import (
is_unitialized,
destroy_model_parallel,
get_data_parallel_group,
get_data_parallel_rank,
get_data_parallel_world_size,
get_embedding_group,
get_model_parallel_group,
get_tensor_model_parallel_group,
get_pipeline_model_parallel_group,
get_tensor_model_parallel_rank,
set_tensor_model_parallel_rank,
get_pipeline_model_parallel_rank,
set_pipeline_model_parallel_rank,
is_pipeline_first_stage,
is_pipeline_last_stage,
get_tensor_model_parallel_src_rank,
get_pipeline_model_parallel_first_rank,
get_pipeline_model_parallel_last_rank,
get_pipeline_model_parallel_next_rank,
get_pipeline_model_parallel_prev_rank,
get_tensor_model_parallel_world_size,
set_tensor_model_parallel_world_size,
get_pipeline_model_parallel_world_size,
set_pipeline_model_parallel_world_size,
get_virtual_pipeline_model_parallel_rank,
set_virtual_pipeline_model_parallel_rank,
initialize_model_parallel,
model_parallel_is_initialized,
)
from apex.transformer import functional
from apex.transformer import parallel_state
from apex.transformer import pipeline_parallel
from apex.transformer import tensor_parallel
from apex.transformer import utils
from apex.transformer.enums import LayerType
from apex.transformer.enums import AttnType
from apex.transformer.enums import AttnMaskType
__all__ = [
"functional",
"parallel_state",
"pipeline_parallel",
"tensor_parallel",
"utils",
# enums.py
"LayerType",
"AttnType",
"AttnMaskType",
]
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
from .fused_softmax import FusedScaleMaskSoftmax
from apex.transformer.functional.fused_softmax import FusedScaleMaskSoftmax
__all__ = [
"FusedScaleMaskSoftmax",
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -15,7 +15,7 @@
import torch
from apex._autocast_utils import _cast_if_autocast_enabled
from ..enums import AttnMaskType
from apex.transformer.enums import AttnMaskType
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -15,35 +15,41 @@
"""Megatron number of micro-batches calculators."""
from abc import ABC
from abc import abstractmethod
from typing import Optional, List
def build_num_microbatches_calculator(args):
def build_num_microbatches_calculator(
rank: int,
rampup_batch_size: Optional[List[int]],
global_batch_size: int,
micro_batch_size: int,
data_parallel_size: int,
):
# Constant num micro-batches.
if args.rampup_batch_size is None:
if rampup_batch_size is None:
num_microbatches_calculator = ConstantNumMicroBatches(
args.global_batch_size, args.micro_batch_size, args.data_parallel_size
global_batch_size, micro_batch_size, data_parallel_size
)
if args.rank == 0:
if rank == 0:
print(
"setting number of micro-batches to constant {}".format(num_microbatches_calculator.get()), flush=True
)
else:
assert len(args.rampup_batch_size) == 3, (
assert len(rampup_batch_size) == 3, (
"expected the following "
"format: --rampup-batch-size <start batch size> "
"<batch size incerement> <ramp-up samples>"
)
start_batch_size = int(args.rampup_batch_size[0])
batch_size_increment = int(args.rampup_batch_size[1])
ramup_samples = int(args.rampup_batch_size[2])
if args.rank == 0:
start_batch_size = int(rampup_batch_size[0])
batch_size_increment = int(rampup_batch_size[1])
ramup_samples = int(rampup_batch_size[2])
if rank == 0:
print(
"will use batch size rampup starting from global batch "
"size {} to global batch size {} with batch size increments "
"{} over {} samples.".format(
start_batch_size, args.global_batch_size, batch_size_increment, ramup_samples
start_batch_size, global_batch_size, batch_size_increment, ramup_samples
),
flush=True,
)
......@@ -51,9 +57,9 @@ def build_num_microbatches_calculator(args):
start_batch_size,
batch_size_increment,
ramup_samples,
args.global_batch_size,
args.micro_batch_size,
args.data_parallel_size,
global_batch_size,
micro_batch_size,
data_parallel_size,
)
return num_microbatches_calculator
......@@ -86,6 +92,8 @@ class ConstantNumMicroBatches(NumMicroBatchesCalculator):
assert self.num_micro_batches >= 1
self.current_global_batch_size = global_batch_size
self.micro_batch_size = micro_batch_size
def update(self, consumed_samples, consistency_check):
pass
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -17,7 +17,7 @@ import torch
# TODO (mkozuki): Consider dissecting utils as this utils import is here
# only for ensure_divisibility
from .tensor_parallel import utils
from apex.transformer.utils import ensure_divisibility
# Intra-layer model parallel group that the current rank belongs to.
......@@ -76,17 +76,17 @@ def initialize_model_parallel(
with a total of 16 GPUs, rank 0 to 7 belong to the first box and
ranks 8 to 15 belong to the second box.
"""
if torch.distributed.get_rank() == 0:
print("> initializing tensor model parallel with size {}".format(tensor_model_parallel_size_))
print("> initializing pipeline model parallel with size {}".format(pipeline_model_parallel_size_))
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
world_size = torch.distributed.get_world_size()
tensor_model_parallel_size = min(tensor_model_parallel_size_, world_size)
pipeline_model_parallel_size = min(pipeline_model_parallel_size_, world_size)
# TODO (mkozuki): Consider moving `ensure_divisibility` to this file.
utils.ensure_divisibility(world_size, tensor_model_parallel_size * pipeline_model_parallel_size)
ensure_divisibility(world_size, tensor_model_parallel_size * pipeline_model_parallel_size)
data_parallel_size = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size)
if torch.distributed.get_rank() == 0:
print("> initializing tensor model parallel with size {}".format(tensor_model_parallel_size))
print("> initializing pipeline model parallel with size {}".format(pipeline_model_parallel_size))
print("> initializing data parallel with size {}".format(data_parallel_size))
num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
......@@ -344,3 +344,15 @@ def destroy_model_parallel():
_DATA_PARALLEL_GROUP = None
global _EMBEDDING_GROUP
_EMBEDDING_GROUP = None
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
global _MPU_TENSOR_MODEL_PARALLEL_RANK
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
global _MPU_PIPELINE_MODEL_PARALLEL_RANK
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None
from apex.transformer.pipeline_parallel.schedules import get_forward_backward_func
from apex.transformer.pipeline_parallel.schedules.common import build_model
__all__ = [
"get_forward_backward_func",
"build_model",
]
import time
import torch
class _Timer:
"""Timer."""
def __init__(self, name):
self.name_ = name
self.elapsed_ = 0.0
self.started_ = False
self.start_time = time.time()
def start(self):
"""Start the timer."""
assert not self.started_, "timer has already been started"
torch.cuda.synchronize()
self.start_time = time.time()
self.started_ = True
def stop(self):
"""Stop the timer."""
assert self.started_, "timer is not started"
torch.cuda.synchronize()
self.elapsed_ += time.time() - self.start_time
self.started_ = False
def reset(self):
"""Reset timer."""
self.elapsed_ = 0.0
self.started_ = False
def elapsed(self, reset=True):
"""Calculate the elapsed time."""
started_ = self.started_
# If the timing in progress, end it first.
if self.started_:
self.stop()
# Get the elapsed time.
elapsed_ = self.elapsed_
# Reset the elapsed time
if reset:
self.reset()
# If timing was in progress, set it back.
if started_:
self.start()
return elapsed_
class _Timers:
"""Group of timers."""
def __init__(self):
self.timers = {}
def __call__(self, name):
if name not in self.timers:
self.timers[name] = _Timer(name)
return self.timers[name]
def write(self, names, writer, iteration, normalizer=1.0, reset=False):
"""Write timers to a tensorboard writer"""
# currently when using add_scalars,
# torch.utils.add_scalars makes each timer its own run, which
# polutes the runs list, so we just add each as a scalar
assert normalizer > 0.0
for name in names:
value = self.timers[name].elapsed(reset=reset) / normalizer
writer.add_scalar(name + "-time", value, iteration)
def log(self, names, normalizer=1.0, reset=True):
"""Log a group of timers."""
assert normalizer > 0.0
string = "time (ms)"
for name in names:
elapsed_time = self.timers[name].elapsed(reset=reset) * 1000.0 / normalizer
string += " | {}: {:.2f}".format(name, elapsed_time)
if torch.distributed.is_initialized():
if torch.distributed.get_rank() == (torch.distributed.get_world_size() - 1):
print(string, flush=True)
else:
print(string, flush=True)
# coding=utf-8
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import reduce
import operator
from typing import Union, Optional, Tuple
import torch
from apex._autocast_utils import _get_current_dtype
from apex.transformer import parallel_state
from apex.transformer.utils import split_tensor_into_1d_equal_chunks
from apex.transformer.utils import gather_split_1d_tensor
from apex.transformer.pipeline_parallel.utils import Shape
from apex.transformer.pipeline_parallel._timers import _Timers
def _run_p2pops(
tensor_send_prev: Union[torch.Tensor, None],
tensor_send_next: Union[torch.Tensor, None],
tensor_recv_prev: Union[torch.Tensor, None],
tensor_recv_next: Union[torch.Tensor, None],
):
ops = []
if tensor_send_prev is not None:
send_prev_op = torch.distributed.P2POp(
torch.distributed.isend,
tensor_send_prev,
parallel_state.get_pipeline_model_parallel_prev_rank(),
)
ops.append(send_prev_op)
if tensor_recv_prev is not None:
recv_prev_op = torch.distributed.P2POp(
torch.distributed.irecv,
tensor_recv_prev,
parallel_state.get_pipeline_model_parallel_prev_rank(),
)
ops.append(recv_prev_op)
if tensor_send_next is not None:
send_next_op = torch.distributed.P2POp(
torch.distributed.isend,
tensor_send_next,
parallel_state.get_pipeline_model_parallel_next_rank(),
)
ops.append(send_next_op)
if tensor_recv_next is not None:
recv_next_op = torch.distributed.P2POp(
torch.distributed.irecv,
tensor_recv_next,
parallel_state.get_pipeline_model_parallel_next_rank(),
)
ops.append(recv_next_op)
if len(ops) > 0:
reqs = torch.distributed.batch_isend_irecv(ops)
for req in reqs:
req.wait()
# NOTE (mkozuki): Leaving `params_dytpe` as it is for future development in PyTorch, especially APEX O2 style AMP.
# But as of v1.10, basically all tensors are torch.float32 except for output tensors of `autocast` compatible layers.
def _communicate(
tensor_send_next: Optional[torch.Tensor],
tensor_send_prev: Optional[torch.Tensor],
recv_prev: bool,
recv_next: bool,
tensor_shape: Optional[Shape] = None,
override_scatter_gather_tensors_in_pipeline: bool = False,
dtype_: torch.dtype = torch.float,
*,
scatter_gather_tensors_in_pipeline: bool = True,
params_dtype: Optional[torch.dtype] = None,
fp32_residual_connection: bool = False,
) -> Tuple[Union[torch.Tensor, None], Union[torch.Tensor, None]]:
"""Base function for communication of tensors between stages.
Args:
tensor_send_next: tensor to send to next rank (no tensor sent if set to None).
tensor_send_prev: tensor to send to prev rank (no tensor sent if set to None).
recv_prev: boolean for whether tensor should be received from previous rank.
recv_next: boolean for whether tensor should be received from next rank.
tensor_shape: optional, use when the input sequence contains less tokens than the default sequence length
override_scatter_gather_tensors_in_pipeline:
optional, this is used when tensor_shape is provided to override scatter gather tensors
dtype_: This is used when tensor_shape is provided and what is the type of tensor_shape
Keyword args:
scatter_gather_tensors_in_pipeline: Optional. If :obj:`True`, use scatter/gather to optimize communication of tensors.
params_dtype: Optional and legacy. Defaults to torch.float. If you manually call `.half()` or `.bfloat16()` on
your model deliberately, pass this argument.
fp32_residual_connection: Optional. If :obj:`True`, move residual connections to fp32.
Returns:
tuple containing
- tensor_recv_prev: `torch.Tensor` if `recv_prev` is :obj:`True`, `None` otherwise.
- tensor_recv_next: `torch.Tensor` if `recv_next` is :obj:`True`, `None` otherwise.
"""
# Create placeholder tensors for receive in forward and backward directions if needed.
tensor_recv_prev = None
tensor_recv_next = None
if tensor_shape is None:
# In megatron, `tensor_shape` is set to `(args.seq_length, args.micro_batch_size, args.hidden_size)`
raise RuntimeError(
"`tensor_shape` must be specified. Common `tensor_shape` is `(seq_length, micro_batch_size, hidden_size)`")
if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
tensor_chunk_shape = (reduce(operator.mul, tensor_shape, 1) // parallel_state.get_tensor_model_parallel_world_size(),)
else:
tensor_chunk_shape = tensor_shape
dtype = params_dtype or torch.float
if fp32_residual_connection:
dtype = torch.float
requires_grad = True
if dtype_ is not None:
dtype = dtype_
requires_grad = False
if recv_prev:
tensor_recv_prev = torch.empty(
tensor_chunk_shape,
requires_grad=requires_grad,
device=torch.cuda.current_device(),
dtype=dtype,
)
if recv_next:
tensor_recv_next = torch.empty(
tensor_chunk_shape,
requires_grad=requires_grad,
device=torch.cuda.current_device(),
dtype=dtype,
)
# Split tensor into smaller chunks if using scatter-gather optimization.
if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
if tensor_send_next is not None:
tensor_send_next = split_tensor_into_1d_equal_chunks(tensor_send_next)
if tensor_send_prev is not None:
tensor_send_prev = split_tensor_into_1d_equal_chunks(tensor_send_prev)
# Send tensors in both the forward and backward directions as appropriate.
_run_p2pops(tensor_send_prev, tensor_send_next, tensor_recv_prev, tensor_recv_next)
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()
# If using scatter-gather optimization, gather smaller chunks.
if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
if recv_prev:
tensor_recv_prev = (
gather_split_1d_tensor(tensor_recv_prev)
.view(tensor_shape)
.requires_grad_()
)
if recv_next:
tensor_recv_next = (
gather_split_1d_tensor(tensor_recv_next)
.view(tensor_shape)
.requires_grad_()
)
return tensor_recv_prev, tensor_recv_next
def recv_forward(
tensor_shape: Shape,
override_scatter_gather_tensors_in_pipeline: bool = False,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
) -> torch.Tensor:
"""Receive tensor from previous rank in pipeline (forward receive)."""
if parallel_state.is_pipeline_first_stage():
return None
if timers is not None:
timers("forward-recv").start()
input_tensor, _ = _communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_prev=True,
recv_next=False,
tensor_shape=tensor_shape,
override_scatter_gather_tensors_in_pipeline=override_scatter_gather_tensors_in_pipeline,
dtype_=_get_current_dtype(dtype),
)
if timers is not None:
timers("forward-recv").stop()
return input_tensor
def recv_backward(
tensor_shape: Shape = None,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
):
"""Receive tensor from next rank in pipeline (backward receive)."""
if parallel_state.is_pipeline_last_stage():
return None
if timers is not None:
timers("backward-recv").start()
_, output_tensor_grad = _communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_prev=False,
recv_next=True,
tensor_shape=tensor_shape,
dtype_=_get_current_dtype(dtype),
)
if timers is not None:
timers("backward-recv").stop()
return output_tensor_grad
def send_forward(
output_tensor: torch.Tensor,
override_scatter_gather_tensors_in_pipeline: bool = False,
tensor_shape: Shape = None,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
) -> None:
"""Send tensor to next rank in pipeline (forward send)."""
if parallel_state.is_pipeline_last_stage():
return
if timers is not None:
timers("forward-send").start()
_communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=False,
recv_next=False,
override_scatter_gather_tensors_in_pipeline=override_scatter_gather_tensors_in_pipeline,
tensor_shape=tensor_shape,
dtype_=_get_current_dtype(dtype),
)
if timers is not None:
timers("forward-send").stop()
def send_backward(
input_tensor_grad: torch.Tensor,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
) -> None:
"""Send tensor to previous rank in pipeline (backward send)."""
if parallel_state.is_pipeline_first_stage():
return
if timers is not None:
timers("backward-send").start()
_communicate(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=False,
recv_next=False,
tensor_shape=tensor_shape,
dtype_=_get_current_dtype(dtype),
)
if timers is not None:
timers("backward-send").stop()
def send_forward_recv_backward(
output_tensor: torch.Tensor,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
) -> None:
"""Batched send and recv with next rank in pipeline."""
if parallel_state.is_pipeline_last_stage():
return None
if timers is not None:
timers("forward-send-backward-recv").start()
_, output_tensor_grad = _communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=False,
recv_next=True,
tensor_shape=tensor_shape,
dtype_=_get_current_dtype(dtype),
)
if timers is not None:
timers("forward-send-backward-recv").stop()
return output_tensor_grad
def send_backward_recv_forward(
input_tensor_grad: torch.Tensor,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
) -> torch.Tensor:
"""Batched send and recv with previous rank in pipeline."""
if parallel_state.is_pipeline_first_stage():
return None
if timers is not None:
timers("backward-send-forward-recv").start()
input_tensor, _ = _communicate(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=True,
recv_next=False,
tensor_shape=tensor_shape,
dtype_=_get_current_dtype(dtype),
)
if timers is not None:
timers("backward-send-forward-recv").stop()
return input_tensor
def send_forward_recv_forward(
output_tensor: torch.Tensor,
recv_prev: bool,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
) -> torch.Tensor:
"""Batched recv from previous rank and send to next rank in pipeline."""
if timers is not None:
timers("forward-send-forward-recv").start()
input_tensor, _ = _communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=recv_prev,
recv_next=False,
tensor_shape=tensor_shape,
dtype_=_get_current_dtype(dtype),
)
if timers is not None:
timers("forward-send-forward-recv").stop()
return input_tensor
def send_backward_recv_backward(
input_tensor_grad: torch.Tensor,
recv_next: bool,
tensor_shape: Shape,
*,
dtype: torch.dtype = torch.float,
timers: _Timers = None,
) -> torch.Tensor:
"""Batched recv from next rank and send to previous rank in pipeline."""
if timers is not None:
timers("backward-send-backward-recv").start()
_, output_tensor_grad = _communicate(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=False,
recv_next=recv_next,
tensor_shape=tensor_shape,
dtype_=_get_current_dtype(dtype),
)
if timers is not None:
timers("backward-send-backward-recv").stop()
return output_tensor_grad
def send_forward_backward_recv_forward_backward(
output_tensor: torch.Tensor,
input_tensor_grad: torch.Tensor,
recv_prev: bool,
recv_next: bool,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
):
"""Batched send and recv with previous and next ranks in pipeline."""
if timers is not None:
timers("forward-backward-send-forward-backward-recv").start()
input_tensor, output_tensor_grad = _communicate(
tensor_send_next=output_tensor,
tensor_send_prev=input_tensor_grad,
recv_prev=recv_prev,
recv_next=recv_next,
tensor_shape=tensor_shape,
dtype_=_get_current_dtype(dtype),
)
if timers is not None:
timers("forward-backward-send-forward-backward-recv").stop()
return input_tensor, output_tensor_grad
import warnings
from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import forward_backward_no_pipelining
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import _forward_backward_pipelining_with_interleaving
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import (
forward_backward_pipelining_without_interleaving,
)
class ExperimentalWarning(Warning):
pass
def get_forward_backward_func(
virtual_pipeline_model_parallel_size, pipeline_model_parallel_size,
):
if parallel_state.get_pipeline_model_parallel_world_size() > 1:
if virtual_pipeline_model_parallel_size is not None:
if get_num_microbatches() % pipeline_model_parallel_size != 0:
msg = "number of microbatches is not divisible by pipeline-parallel size when using interleaved schedule"
raise RuntimeError(msg)
warnings.warn(
"Pipeline Model Parallel with interleaving scheduling is experimental. "
f"To use Pipeline Parallel without interleaving, set `virtual_pipeline_model_parallel_size` to `None`: {virtual_pipeline_model_parallel_size}",
ExperimentalWarning
)
forward_backward_func = _forward_backward_pipelining_with_interleaving
else:
forward_backward_func = forward_backward_pipelining_without_interleaving
else:
forward_backward_func = forward_backward_no_pipelining
return forward_backward_func
__all__ = [
"get_forward_backward_func",
]
# NOTE (mkozuki): For simplicity, tentatively `timers` related operations are commented out.
from typing import Any, Callable, Dict, List, Tuple, Union, Optional
import torch
from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.utils import listify_model
from apex.transformer.pipeline_parallel.utils import unwrap_model
from apex.transformer.tensor_parallel.layers import set_defaults_if_not_set_tensor_model_parallel_attributes
Batch = Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]
LossFunc = Callable[[torch.Tensor], torch.Tensor]
FwdStepFunc = Callable[[Batch, torch.nn.Module], Tuple[torch.Tensor, LossFunc]]
def build_model(
model_provider_func: Callable[[Any, Dict[str, Any]], torch.nn.Module],
wrap_with_ddp: bool = True,
virtual_pipeline_model_parallel_size: Optional[int] = None,
*args,
**kwargs
) -> List[torch.nn.Module]:
"""Build the model satisfying pipeline model parallel requirements.
This function sets `pre_process` and `post_process` to `**kwargs` and pass `*args` and `**kwargs` to
`model_provider_func`.
Args:
model_provider_func: A function which takes `*args` and `**kwargs` and returns a `nn.Module`.
wrap_with_ddp: If :obj:`True`, wrap the instantiated model
with `torch.nn.parallel.distributed.DistributedDataParallel`, a.k.a. `DDP`.
virtual_pipeline_model_parallel_size: Specify when using interleaving scheduling pipeline model parallel.
*args: arguments for model provider func
**kwargs: Keyword arguments for model provider func
Returns:
a list of `nn.Module`(s). If `virtual_pipeline_model_parallel_size` is not None,
the list has multiple models, otherwise one.
"""
if (
parallel_state.get_pipeline_model_parallel_world_size() > 1 and
virtual_pipeline_model_parallel_size is not None
):
model = []
for i in range(virtual_pipeline_model_parallel_size):
cur_args = args
cur_kwargs = kwargs
parallel_state.set_virtual_pipeline_model_parallel_rank(i)
# Set pre_process and post_process only after virtual rank is set.
pre_process = parallel_state.is_pipeline_first_stage()
post_process = parallel_state.is_pipeline_last_stage()
cur_kwargs.update({
"pre_process": pre_process,
"post_process": post_process,
})
this_model = model_provider_func(*cur_args, **cur_kwargs)
model.append(this_model)
else:
cur_args = args
cur_kwargs = kwargs
pre_process = parallel_state.is_pipeline_first_stage()
post_process = parallel_state.is_pipeline_last_stage()
cur_kwargs.update({
"pre_process": pre_process,
"post_process": post_process,
})
model = model_provider_func(*cur_args, **cur_kwargs)
if not isinstance(model, list):
model = [model]
# Set tensor model parallel attributes if not set.
# Only parameters that are already tensor model parallel have these
# attributes set for them. We should make sure the default attributes
# are set for all params so the optimizer can use them.
for model_module in model:
for param in model_module.parameters():
set_defaults_if_not_set_tensor_model_parallel_attributes(param)
# Print number of parameters.
if parallel_state.get_data_parallel_rank() == 0:
msg = " > number of parameters on (tensor, pipeline) model parallel rank ({}, {}): {}".format(
parallel_state.get_tensor_model_parallel_rank(),
parallel_state.get_pipeline_model_parallel_rank(),
sum([sum([p.nelement() for p in model_module.parameters()]) for model_module in model])
)
print(msg, flush=True)
# GPU allocation.
for model_module in model:
model_module.cuda(torch.cuda.current_device())
if wrap_with_ddp:
i = torch.cuda.current_device()
model = [
torch.nn.parallel.distributed.DistributedDataParallel(
model_module,
device_ids=[i],
output_device=i,
process_group=parallel_state.get_data_parallel_group(),
)
for model_module in model
]
return model
def _get_params_for_weight_decay_optimization(
model: Union[torch.nn.Module, List[torch.nn.Module]],
) -> Dict[str, torch.nn.Parameter]:
"""Divide params into with-weight-decay and without-weight-decay groups.
Layernorms and biases will have no weight decay but the rest will.
"""
modules = listify_model(model)
from apex.normalization.fused_layer_norm import FusedLayerNorm # NOQA
weight_decay_params = {'params': []}
no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
for module in modules:
for module_ in module.modules():
if isinstance(module_, FusedLayerNorm):
no_weight_decay_params['params'].extend(
[p for p in list(module_._parameters.values())
if p is not None])
else:
weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and n != 'bias'])
no_weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and n == 'bias'])
return weight_decay_params, no_weight_decay_params
def forward_step(
forward_step_func: FwdStepFunc,
batch: Batch,
model: torch.nn.Module,
input_tensor: Optional[torch.Tensor],
losses_reduced: List[torch.Tensor],
):
"""Forward step for passed-in model.
If first stage, input tensor is obtained from data_iterator, otherwise
passed-in input_tensor is used.
Returns output tensor.
Args:
forward_step_func: Model specific function. This takes a minibatch and model as its arguments and
returns the model's output and the loss function.
batch: minibatch
model: unwrappable model
input_tensor:
losses_reduced:
Returns:
output_tensor
"""
# timers = get_timers()
# timers("forward-compute").start()
unwrapped_model = unwrap_model(model)
# NOTE (mkozuki): The passed `model` is expected to implement `set_input_tensor`.
# See https://github.com/NVIDIA/Megatron-LM/blob/5ac5571ba0265af4c491ee0af1508ca7589450c6/megatron/model/transformer.py#L679 # NOQA
# for the details of `set_input_tensor`.
unwrapped_model.set_input_tensor(input_tensor)
output_tensor, loss_func = forward_step_func(batch, model)
# print(f"forward_step| pipeline rank: {parallel_state.get_pipeline_model_parallel_rank()} is_pipeline_last_stage?: {parallel_state.is_pipeline_last_stage()}")
if parallel_state.is_pipeline_last_stage():
output_tensor = loss_func(output_tensor)
loss, loss_reduced = output_tensor
output_tensor = loss / get_num_microbatches()
losses_reduced.append(loss_reduced)
# timers("forward-compute").stop()
return output_tensor
def backward_step(
input_tensor: Optional[torch.Tensor],
output_tensor: torch.Tensor,
output_tensor_grad: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
"""Backward step through passed-in output tensor.
If last stage, output_tensor_grad is None, otherwise gradient of loss
with respect to stage's output tensor.
Returns gradient of loss with respect to input tensor (None if first
stage).
Args:
input_tensor:
output_tensor:
output_tensor_grad:
Returns:
input_tensor_grad
"""
# timers = get_timers()
# timers("backward-compute").start()
# Retain the grad on the input_tensor.
# if parallel_state.get_pipeline_model_parallel_rank() == 0:
# print(f"{input_tensor}, {output_tensor}, {output_tensor_grad}")
if input_tensor is not None:
input_tensor.retain_grad()
# Backward pass.
# if output_tensor_grad is None:
# output_tensor = optimizer.scale_loss(output_tensor)
torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
input_tensor_grad = None
if input_tensor is not None:
input_tensor_grad = input_tensor.grad
# timers("backward-compute").stop()
return input_tensor_grad
from contextlib import contextmanager
from typing import List, Union
import torch
from apex.transformer.pipeline_parallel.utils import listify_model
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.utils import get_kth_microbatch
from apex.transformer.pipeline_parallel.schedules.common import Batch, FwdStepFunc
from apex.transformer.pipeline_parallel.schedules.common import forward_step
from apex.transformer.pipeline_parallel.schedules.common import backward_step
@contextmanager
def placeholder_handler():
try:
yield
finally:
pass
# TODO (mkozuki): Confirm this will be used or not.
# TODO (mkozuki): Fix if necessary. Currently I'm seeing failure if `not forward_only` and
# the last `backward_step` seems to fail. However, note the possibility of my test script is wrong.
def forward_backward_no_pipelining(
forward_step_func: FwdStepFunc,
batch: Batch,
model: Union[torch.nn.Module, List[torch.nn.Module]],
*,
forward_only: bool,
**kwargs,
):
"""Run forward and backward passes with no pipeline parallelism (no inter-stage communication).
This pipeline parallel scheduling handles the last microbatch differently to synchronize gradients.
Args:
forward_step_func: A function which takes a minibatch and model as its arguments and
returns model's forward output and the loss function.
The loss function is supposed to take one `torch.Tensor` and
return a `torch.Tensor` of loss and a dictionary of `str` and `torch.Tensor`.
batch: A List of torch.Tensors
model: A `torch.nn.Module` or a list of `torch.nn.Module`.
Keyword args:
forward_only:
**kwargs: Added to handle `tensor_shape` which has no effect on this function.
Returns:
a list of dictionaries of loss `torch.Tensor`s if the last stage, empty list otherwise.
"""
model = listify_model(model)
if len(model) != 1:
msg = f"`model` is expected be a `nn.Module`, but {type(model)}"
raise RuntimeError(msg)
model = model[0]
context_handler = placeholder_handler
if isinstance(model, torch.nn.parallel.distributed.DistributedDataParallel):
context_handler = model.no_sync
losses_reduced = []
input_tensor, output_tensor_grad = None, None
num_micro_batches = get_num_microbatches()
with context_handler():
for i in range(num_micro_batches - 1):
cur_micro_batch = get_kth_microbatch(batch, i)
output_tensor = forward_step(
forward_step_func, cur_micro_batch, model, input_tensor, losses_reduced)
if not forward_only:
backward_step(input_tensor, output_tensor, output_tensor_grad)
# Run computation for last microbatch out of context handler (want to
# synchronize gradients).
output_tensor = forward_step(
forward_step_func, get_kth_microbatch(batch, num_micro_batches - 1), model, input_tensor, losses_reduced
)
if not forward_only:
backward_step(input_tensor, output_tensor, output_tensor_grad)
return losses_reduced
from typing import List, Union, Optional
import torch
from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel import p2p_communication
from apex.transformer.pipeline_parallel.schedules.common import Batch, FwdStepFunc
from apex.transformer.pipeline_parallel.schedules.common import backward_step
from apex.transformer.pipeline_parallel.schedules.common import forward_step
from apex.transformer.pipeline_parallel.utils import get_kth_microbatch
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.utils import rank_print
# TODO (mkozuki): Reduce cyclomatic complexity
def _forward_backward_pipelining_with_interleaving(
forward_step_func: FwdStepFunc,
batch: List[Batch],
model: List[torch.nn.Module],
*,
forward_only: bool,
tensor_shape: Optional[Union[List[int], torch.Size]] = None,
):
"""Run interleaved 1F1B schedule with communication between pipeline stages as needed.
This function assumes `batch` and `model` is a list of `Batch`'s and a list of `torch.nn.Module`, respectively.
This means that model is split into model chunks.
This pipeline parallel scheduling consists of three steps:
1. warmup
2. 1F1B a.k.a. steady state
3. cooldown
Note that if `forward_only` this scheduling consists of only warmup phase.
Args:
forward_step_func: A function which takes a minibatch and model as its arguments and
returns model's forward output and the loss function.
The loss function is supposed to take one `torch.Tensor` and
return a `torch.Tensor` of loss and a dictionary of `str` and `torch.Tensor`.
batch: A minibatch, i.e., a list of `torch.Tensor`'s.
model: A `torch.nn.Module` or a list of `torch.nn.Module`.
Keyword args:
forward_only:
tensor_shape: Shape of tensor.
Returns:
a list of loss `torch.Tensor`s if the last stage, empty list otherwise.
"""
if not isinstance(model, list):
raise RuntimeError("`model` must be a list of `nn.Module`'s'")
# TODO (mkozuki): Sanity check the following condition.
if len(batch) != len(model):
msg = f"`batch` and `model` must have the same number of elements. Actual {len(batch)} and {len(model)}"
raise RuntimeError(msg)
num_model_chunks = len(model)
input_tensors = [[] for _ in range(num_model_chunks)]
output_tensors = [[] for _ in range(num_model_chunks)]
curr_iters = [0 for _ in range(num_model_chunks)]
losses_reduced = []
if not forward_only:
output_tensor_grads = [[] for _ in range(num_model_chunks)]
pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
pipeline_parallel_rank = parallel_state.get_pipeline_model_parallel_rank()
# Compute number of warmup and remaining microbatches.
num_microbatches = get_num_microbatches() * num_model_chunks
all_warmup_microbatches = False
if forward_only:
num_warmup_microbatches = num_microbatches
else:
# Run all forward passes and then all backward passes if number of
# microbatches is just the number of pipeline stages.
# Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on
# all workers, followed by more microbatches after depending on
# stage ID (more forward passes for earlier stages, later stages can
# immediately start with 1F1B).
if get_num_microbatches() == pipeline_parallel_size:
num_warmup_microbatches = num_microbatches
all_warmup_microbatches = True
else:
num_warmup_microbatches = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size
num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
num_microbatches_remaining = num_microbatches - num_warmup_microbatches
# TODO (mkozuki): Remove once debug gets done
# rank_print(
# f"num_microbatches: {num_microbatches}, "
# f"num_warmup_microbatches: {num_warmup_microbatches}, "
# f"num_microbatches_remaining: {num_microbatches_remaining} -- "
# )
###################################################################################################################
# Helper function definitions.
###################################################################################################################
def get_model_chunk_id(microbatch_id: int, forward: bool) -> int:
"""Helper function to get the model chunk ID given the iteration number."""
pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
if not forward:
model_chunk_id = num_model_chunks - model_chunk_id - 1
return model_chunk_id
def forward_step_helper(microbatch_id, curr_iters):
"""Helper method to run forward step with model split into chunks
(run set_virtual_pipeline_model_parallel_rank() before calling
forward_step())."""
model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
# forward step
if (
parallel_state.is_pipeline_first_stage() and
len(input_tensors[model_chunk_id]) == len(output_tensors[model_chunk_id])
):
input_tensors[model_chunk_id].append(None)
input_tensor = input_tensors[model_chunk_id][-1]
output_tensor = forward_step(
forward_step_func,
get_kth_microbatch(batch[model_chunk_id], curr_iters[model_chunk_id]),
model[model_chunk_id],
input_tensor,
losses_reduced,
)
curr_iters[model_chunk_id] += 1
output_tensors[model_chunk_id].append(output_tensor)
# if forward-only, no need to save tensors for a backward pass
if forward_only:
input_tensors[model_chunk_id].pop()
output_tensors[model_chunk_id].pop()
return output_tensor
def backward_step_helper(microbatch_id):
"""Helper method to run backward step with model split into chunks
(run set_virtual_pipeline_model_parallel_rank() before calling
backward_step())."""
model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
if parallel_state.is_pipeline_last_stage():
if len(output_tensor_grads[model_chunk_id]) == 0:
output_tensor_grads[model_chunk_id].append(None)
input_tensor = input_tensors[model_chunk_id].pop(0)
output_tensor = output_tensors[model_chunk_id].pop(0)
output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
input_tensor_grad = backward_step(input_tensor, output_tensor, output_tensor_grad)
return input_tensor_grad
###################################################################################################################
# Run warmup forward passes.
###################################################################################################################
parallel_state.set_virtual_pipeline_model_parallel_rank(0)
input_tensors[0].append(p2p_communication.recv_forward(tensor_shape=tensor_shape))
for k in range(num_warmup_microbatches):
# rank_print(f"warmup iter: {k}")
output_tensor = forward_step_helper(k, curr_iters)
# Determine if tensor should be received from previous stage.
next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True)
recv_prev = True
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
if next_forward_model_chunk_id == 0:
recv_prev = False
if k == (num_microbatches - 1):
recv_prev = False
# Don't send tensor downstream if on last stage.
if parallel_state.is_pipeline_last_stage():
output_tensor = None
# rank_print(f"recv_prev: {recv_prev}")
# Send and receive tensors as appropriate (send tensors computed
# in this iteration; receive tensors for next iteration).
if k == (num_warmup_microbatches - 1) and not forward_only and not all_warmup_microbatches:
input_tensor_grad = None
recv_next = True
# rank_print(f"recv_next: {recv_next}")
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
recv_next = False
(
input_tensor,
output_tensor_grad,
) = p2p_communication.send_forward_backward_recv_forward_backward(
output_tensor,
input_tensor_grad,
recv_prev=recv_prev,
recv_next=recv_next,
tensor_shape=tensor_shape,
)
output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
else:
# rank_print("send_forward_recv_forward start")
input_tensor = p2p_communication.send_forward_recv_forward(output_tensor, recv_prev=recv_prev, tensor_shape=tensor_shape)
# rank_print("send_forward_recv_forward finish")
# rank_print("communication done")
input_tensors[next_forward_model_chunk_id].append(input_tensor)
###################################################################################################################
# Run 1F1B in steady state.
###################################################################################################################
for k in range(num_microbatches_remaining):
# Forward pass.
forward_k = k + num_warmup_microbatches
output_tensor = forward_step_helper(forward_k, curr_iters)
# Backward pass.
backward_k = k
input_tensor_grad = backward_step_helper(backward_k)
# Send output_tensor and input_tensor_grad, receive input_tensor
# and output_tensor_grad.
# Determine if current stage has anything to send in either direction,
# otherwise set tensor to None.
forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
if parallel_state.is_pipeline_last_stage():
output_tensor = None
backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
if parallel_state.is_pipeline_first_stage():
input_tensor_grad = None
# Determine if peers are sending, and where in data structure to put
# received tensors.
recv_prev = True
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
# First stage is ahead of last stage by (pipeline_parallel_size - 1).
next_forward_model_chunk_id = get_model_chunk_id(
forward_k - (pipeline_parallel_size - 1), forward=True
)
if next_forward_model_chunk_id == (num_model_chunks - 1):
recv_prev = False
next_forward_model_chunk_id += 1
else:
next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True)
recv_next = True
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
# Last stage is ahead of first stage by (pipeline_parallel_size - 1).
next_backward_model_chunk_id = get_model_chunk_id(
backward_k - (pipeline_parallel_size - 1), forward=False
)
if next_backward_model_chunk_id == 0:
recv_next = False
next_backward_model_chunk_id -= 1
else:
next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False)
# If last iteration, don't receive; we already received one extra
# before the start of the for loop.
if k == (num_microbatches_remaining - 1):
recv_prev = False
# Communicate tensors.
(
input_tensor,
output_tensor_grad,
) = p2p_communication.send_forward_backward_recv_forward_backward(
output_tensor,
input_tensor_grad,
recv_prev=recv_prev,
recv_next=recv_next,
tensor_shape=tensor_shape,
)
# Put input_tensor and output_tensor_grad in data structures in the
# right location.
if recv_prev:
input_tensors[next_forward_model_chunk_id].append(input_tensor)
if recv_next:
output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad)
###################################################################################################################
# Run cooldown backward passes (flush out pipeline).
###################################################################################################################
if not forward_only:
if all_warmup_microbatches:
output_tensor_grads[num_model_chunks - 1].append(p2p_communication.recv_backward(tensor_shape=tensor_shape))
for k in range(num_microbatches_remaining, num_microbatches):
input_tensor_grad = backward_step_helper(k)
next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
recv_next = True
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
if next_backward_model_chunk_id == (num_model_chunks - 1):
recv_next = False
if k == (num_microbatches - 1):
recv_next = False
output_tensor_grads[next_backward_model_chunk_id].append(
p2p_communication.send_backward_recv_backward(input_tensor_grad, recv_next=recv_next, tensor_shape=tensor_shape)
)
return losses_reduced
from typing import Union, List, Optional
import torch
from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel import p2p_communication
from apex.transformer.pipeline_parallel.utils import get_kth_microbatch
from apex.transformer.pipeline_parallel.utils import listify_model
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.schedules.common import Batch, FwdStepFunc
from apex.transformer.pipeline_parallel.schedules.common import forward_step
from apex.transformer.pipeline_parallel.schedules.common import backward_step
from apex.transformer.utils import rank_print
def forward_backward_pipelining_without_interleaving(
forward_step_func: FwdStepFunc,
batch: Batch,
model: Union[torch.nn.Module, List[torch.nn.Module]],
*,
forward_only: bool,
tensor_shape: Optional[Union[List[int], torch.Size]] = None,
):
"""Run non-interleaved 1F1B schedule, with communication between pipeline stages.
This pipeline parallel scheduling consists of three steps:
1. warmup
2. 1F1B a.k.a. steady state
3. cooldown if not forward_only
Args:
forward_step_func: A function which takes a minibatch and model as its arguments and
returns model's forward output and the loss function.
The loss function is supposed to take one `torch.Tensor` and
return a `torch.Tensor` of loss and a dictionary of `str` and `torch.Tensor`.
batch: A minibatch, i.e., a list of `torch.Tensor`'s.
model: A `torch.nn.Module` or a list of `torch.nn.Module`.
Keyword args:
forward_only:
tensor_shape: Shape of tensor. Required for P2P communication.
Returns:
a list of loss `torch.Tensor`s if the last stage, empty list otherwise.
"""
# timers = get_timers()
model = listify_model(model)
if len(model) != 1:
msg = f"`model` is expected be a `nn.Module`, but {type(model)}"
raise RuntimeError(msg)
model = model[0]
# Compute number of warmup microbatches.
num_microbatches = get_num_microbatches()
num_warmup_microbatches = (
parallel_state.get_pipeline_model_parallel_world_size()
- parallel_state.get_pipeline_model_parallel_rank()
- 1
)
num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
num_microbatches_remaining = num_microbatches - num_warmup_microbatches
# TODO (mkozuki): Remove once debug gets done
print(
f">>> rank: {torch.distributed.get_rank()}, "
f"num_microbatches: {num_microbatches}, "
f"num_warmup_microbatches: {num_warmup_microbatches}, "
f"num_microbatches_remaining: {num_microbatches_remaining} -- "
)
# Input, output tensors only need to be saved when doing backward passes
input_tensors = None
output_tensors = None
if not forward_only:
input_tensors = []
output_tensors = []
losses_reduced = []
###################################################################################################################
# Run warmup forward passes.
###################################################################################################################
# rank_print(f"warmup: {num_warmup_microbatches}")
for i in range(num_warmup_microbatches):
input_tensor = p2p_communication.recv_forward(tensor_shape=tensor_shape)
cur_microbatch = get_kth_microbatch(batch, i)
output_tensor = forward_step(forward_step_func, cur_microbatch, model, input_tensor, losses_reduced)
p2p_communication.send_forward(output_tensor, tensor_shape=tensor_shape)
if not forward_only:
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
# rank_print(f"warmup iter: {i + 1} / {num_warmup_microbatches}")
# rank_print("warmup done")
# Before running 1F1B, need to receive first forward tensor.
# If all microbatches are run in warmup / cooldown phase, then no need to
# receive this tensor here.
# rank_print(f"num microbatches remaining: {num_microbatches_remaining}")
if num_microbatches_remaining > 0:
# rank_print(f"recv_forward before steady state start")
input_tensor = p2p_communication.recv_forward(tensor_shape=tensor_shape)
# rank_print(f"recv_forward before steady state done")
###################################################################################################################
# Run 1F1B in steady state.
###################################################################################################################
# rank_print(f"steady: {num_microbatches_remaining} iters")
for i in range(num_microbatches_remaining):
# rank_print(f"steady: iter {i + 1} / {num_microbatches_remaining} iters")
# if not forward_only:
# rank_print(f"len(input_tensors) = {len(input_tensors)}, len(output_tensors) = {len(output_tensors)}")
last_iteration = i == (num_microbatches_remaining - 1)
cur_microbatch = get_kth_microbatch(batch, i + num_warmup_microbatches)
output_tensor = forward_step(forward_step_func, cur_microbatch, model, input_tensor, losses_reduced)
if forward_only:
# rank_print(f"steady, no backward: `send_forward` start")
p2p_communication.send_forward(output_tensor, tensor_shape=tensor_shape)
if not last_iteration:
input_tensor = p2p_communication.recv_forward(tensor_shape=tensor_shape)
# rank_print(f"steady, no backward: `send_forward` finish")
else:
# rank_print("L.124 steady, backward: `send_forward_recv_backward` start")
output_tensor_grad = p2p_communication.send_forward_recv_backward(output_tensor, tensor_shape=tensor_shape)
# rank_print("L.124 steady, backward: `send_forward_recv_backward` finish")
# Add input_tensor and output_tensor to end of list.
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
# Pop input_tensor and output_tensor from the start of the list for the backward pass.
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
input_tensor_grad = backward_step(
input_tensor, output_tensor, output_tensor_grad
)
if last_iteration:
input_tensor = None
# rank_print(f"L.142 steady backward last iteration: `send_backward` start")
p2p_communication.send_backward(input_tensor_grad, tensor_shape=tensor_shape)
# rank_print(f"L.142 steady backward last iteration: `send_backward` finish")
else:
# rank_print(f"L.146 steady backward: `send_backward_recv_forward` start")
input_tensor = p2p_communication.send_backward_recv_forward(
input_tensor_grad, tensor_shape=tensor_shape)
# rank_print(f"L.146 steady backward: `send_backward_recv_forward` finish")
# rank_print(f"steady: exit")
###################################################################################################################
# Run cooldown backward passes.
###################################################################################################################
if not forward_only:
# rank_print(f"cooldownk: {num_warmup_microbatches} iters")
for i in range(num_warmup_microbatches):
# rank_print(f"cooldown iter: {i + 1} / {num_warmup_microbatches}")
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
# rank_print(f"cooldown waiting for grad tensor")
output_tensor_grad = p2p_communication.recv_backward(tensor_shape=tensor_shape)
# rank_print(f"cooldown received grad tensor")
input_tensor_grad = backward_step(
input_tensor, output_tensor, output_tensor_grad
)
# rank_print(f"cooldown sending grad tensor")
p2p_communication.send_backward(input_tensor_grad, tensor_shape=tensor_shape)
# rank_print(f"cooldownk exit")
return losses_reduced
# coding=utf-8
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for pipeline model parallel."""
from typing import Optional, List, Union
import torch
from torch.nn.parallel import DistributedDataParallel
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
from apex.transformer import parallel_state
from apex.transformer.microbatches import build_num_microbatches_calculator
from apex.transformer.pipeline_parallel._timers import _Timers
_GLOBAL_ARGS = None
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
_GLOBAL_TOKENIZER = None
_GLOBAL_TENSORBOARD_WRITER = None
_GLOBAL_AUTORESUME = None
_GLOBAL_TIMERS = None
Shape = Union[List[int], torch.Size]
def listify_model(model: Union[torch.nn.Module, List[torch.nn.Module]]) -> List[torch.nn.Module]:
if isinstance(model, list):
return model
return [model]
def _ensure_var_is_initialized(var, name):
"""Make sure the input variable is not None."""
assert var is not None, "{} is not initialized.".format(name)
def _ensure_var_is_not_initialized(var, name):
"""Make sure the input variable is not None."""
assert var is None, "{} is already initialized.".format(name)
def setup_microbatch_calculator(
rank: int,
rampup_batch_size: Optional[List[int]],
global_batch_size: int,
micro_batch_size: int,
data_parallel_size: int,
) -> None:
global _GLOBAL_NUM_MICROBATCHES_CALCULATOR
_ensure_var_is_not_initialized(_GLOBAL_NUM_MICROBATCHES_CALCULATOR, 'num microbatches calculator')
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = build_num_microbatches_calculator(
rank, rampup_batch_size, global_batch_size, micro_batch_size, data_parallel_size)
def get_micro_batch_size():
return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.micro_batch_size
def get_num_microbatches():
return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get()
def get_current_global_batch_size():
return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get_current_global_batch_size()
def update_num_microbatches(consumed_samples, consistency_check=True):
_GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples, consistency_check)
# note (mkozuki): Comment out in favor of `get_kth_microbatch`
def _split_batch_into_microbatch(
batch: List[torch.Tensor],
*,
_micro_batch_size: Optional[int] = None,
_global_batch_size: Optional[int] = None,
) -> List[List[torch.Tensor]]:
micro_batch_size = _micro_batch_size
global_batch_size = _global_batch_size
if micro_batch_size is None:
micro_batch_size = get_micro_batch_size()
if global_batch_size is None:
global_batch_size = get_current_global_batch_size()
for i in range(0, global_batch_size, micro_batch_size):
yield [x[i * micro_batch_size:(i + 1) * micro_batch_size] for x in batch]
# TODO(mkozuki): Support non-tensor local minibatches?
def get_kth_microbatch(batch: List[torch.Tensor], k: int) -> List[torch.Tensor]:
"""Create a list of microbatches from a list of local minibatches.
This function creates a list of `k`th microbatches from a list of local minibatches.
`a local minibatch` consists of `global_batch_size / data_parallel_size` samples.
"""
micro_batch_size = get_micro_batch_size()
return [x[k * micro_batch_size:(k + 1) * micro_batch_size] for x in batch]
def get_autoresume():
return _GLOBAL_AUTORESUME
def _set_timers():
"""Initialize timers."""
global _GLOBAL_TIMERS
_ensure_var_is_not_initialized(_GLOBAL_TIMERS, "timers")
_GLOBAL_TIMERS = _Timers()
def get_timers():
"""Return timers."""
_ensure_var_is_initialized(_GLOBAL_TIMERS, "timers")
return _GLOBAL_TIMERS
def print_rank_0(message: str) -> None:
"""If distributed is initialized, print only on rank 0."""
if torch.distributed.is_initialized():
if torch.distributed.get_rank() == 0:
print(message, flush=True)
else:
print(message, flush=True)
def is_last_rank():
return torch.distributed.get_rank() == (torch.distributed.get_world_size() - 1)
def print_rank_last(message):
"""If distributed is initialized, print only on last rank."""
if torch.distributed.is_initialized():
if is_last_rank():
print(message, flush=True)
else:
print(message, flush=True)
def param_is_not_shared(param: torch.nn.Parameter) -> bool:
return getattr(param, "shared", False)
def unwrap_model(model, module_instances=(DistributedDataParallel,)):
return_list = True
if not isinstance(model, list):
model = [model]
return_list = False
unwrapped_model = []
for model_module in model:
while isinstance(model_module, module_instances):
model_module = model_module.module
unwrapped_model.append(model_module)
if not return_list:
return unwrapped_model[0]
return unwrapped_model
def calc_params_l2_norm(model: torch.nn.Module, bf16: bool):
"""Calculate l2 norm of parameters """
# args = get_args()
if not isinstance(model, list):
model = [model]
# Remove duplicate params.
params_data = []
for model_ in model:
for param in model_.parameters():
is_not_shared = param_is_not_shared(param)
is_not_tp_duplicate = parallel_state.param_is_not_tensor_parallel_duplicate(param)
if is_not_shared and is_not_tp_duplicate:
if bf16:
params_data.append(param.data.float())
else:
params_data.append(param.data)
# Calculate norm
dummy_overflow_buf = torch.cuda.IntTensor([0])
norm, _ = multi_tensor_applier(
amp_C.multi_tensor_l2norm, dummy_overflow_buf, [params_data], False # no per-parameter norm
)
norm_2 = norm * norm
# Sum across all model-parallel GPUs.
torch.distributed.all_reduce(
norm_2, op=torch.distributed.ReduceOp.SUM, group=parallel_state.get_model_parallel_group()
)
return norm_2.item() ** 0.5
def average_losses_across_data_parallel_group(losses):
"""Reduce a tensor of losses across all GPUs."""
averaged_losses = torch.cat([loss.clone().detach().view(1) for loss in losses])
torch.distributed.all_reduce(averaged_losses, group=parallel_state.get_data_parallel_group())
averaged_losses = averaged_losses / torch.distributed.get_world_size(
group=parallel_state.get_data_parallel_group()
)
return averaged_losses
def report_memory(name):
"""Simple GPU memory report."""
mega_bytes = 1024.0 * 1024.0
string = name + " memory (MB)"
string += " | allocated: {}".format(torch.cuda.memory_allocated() / mega_bytes)
string += " | max allocated: {}".format(torch.cuda.max_memory_allocated() / mega_bytes)
string += " | reserved: {}".format(torch.cuda.memory_reserved() / mega_bytes)
string += " | max reserved: {}".format(torch.cuda.max_memory_reserved() / mega_bytes)
if parallel_state.get_data_parallel_rank() == 0:
print("[Rank {}] {}".format(torch.distributed.get_rank(), string), flush=True)
def print_params_min_max_norm(optimizer, iteration):
"""Print min, max, and norm of all parameters."""
index = 0
rank = torch.distributed.get_rank()
string = "iteration, rank, index, tensor-model-parallel, min, max, norm\n"
optimizer_ = optimizer.optimizer
for param_group in optimizer_.param_groups:
for param in param_group["params"]:
index += 1
min_ = param.data.min()
max_ = param.data.max()
norm = torch.linalg.norm(param.data)
string += "{:7d}, {:4d}, {:4d}, {:2d}, ".format(
iteration, rank, index, int(param.tensor_model_parallel)
)
string += "{:.6E}, {:.6E}, {:.6E}\n".format(min_, max_, norm)
print(string, flush=True)
# NOTE (mkozuki): APEX doesn't have anything equivalent for
# `_GLOBAL_ADLR_AUTORESUME` like Megatron-LM.
# def check_adlr_autoresume_termination(iteration, model, optimizer, lr_scheduler, save: bool):
# """Check for autoresume signal and exit if it is received."""
# from apex.ppu.checkpointing import save_checkpoint
#
# autoresume = get_adlr_autoresume()
# # Add barrier to ensure consistency.
# torch.distributed.barrier()
# if autoresume.termination_requested():
# if save:
# save_checkpoint(iteration, model, optimizer, lr_scheduler)
# print_rank_0(">>> autoresume termination request found!")
# if torch.distributed.get_rank() == 0:
# autoresume.request_resume()
# print_rank_0(">>> training terminated. Returning")
# sys.exit(0)
def get_ltor_masks_and_position_ids(
data, eod_token, reset_position_ids, reset_attention_mask, eod_mask_loss
):
"""Build masks and position id for left to right model."""
# Extract batch size and sequence length.
micro_batch_size, seq_length = data.size()
# Attention mask (lower triangular).
if reset_attention_mask:
att_mask_batch = micro_batch_size
else:
att_mask_batch = 1
attention_mask = torch.tril(
torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)
).view(att_mask_batch, 1, seq_length, seq_length)
# Loss mask.
loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
if eod_mask_loss:
loss_mask[data == eod_token] = 0.0
# Position ids.
position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
position_ids = position_ids.unsqueeze(0).expand_as(data)
# We need to clone as the ids will be modifed based on batch index.
if reset_position_ids:
position_ids = position_ids.clone()
if reset_position_ids or reset_attention_mask:
# Loop through the batches:
for b in range(micro_batch_size):
# Find indecies where EOD token is.
eod_index = position_ids[b, data[b] == eod_token]
# Detach indecies from positions if going to modify positions.
if reset_position_ids:
eod_index = eod_index.clone()
# Loop through EOD indecies:
prev_index = 0
for j in range(eod_index.size()[0]):
i = eod_index[j]
# Mask attention loss.
if reset_attention_mask:
attention_mask[b, 0, (i + 1) :, : (i + 1)] = 0
# Reset positions.
if reset_position_ids:
position_ids[b, (i + 1) :] -= i + 1 - prev_index
prev_index = i + 1
# Convert attention mask to binary:
attention_mask = attention_mask < 0.5
return attention_mask, loss_mask, position_ids
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -15,11 +15,11 @@
"""Model parallel utility interface."""
from .cross_entropy import vocab_parallel_cross_entropy
from apex.transformer.tensor_parallel.cross_entropy import vocab_parallel_cross_entropy
from .data import broadcast_data
from apex.transformer.tensor_parallel.data import broadcast_data
from .layers import (
from apex.transformer.tensor_parallel.layers import (
ColumnParallelLinear,
RowParallelLinear,
VocabParallelEmbedding,
......@@ -28,7 +28,7 @@ from .layers import (
copy_tensor_model_parallel_attributes,
)
from .mappings import (
from apex.transformer.tensor_parallel.mappings import (
copy_to_tensor_model_parallel_region,
gather_from_tensor_model_parallel_region,
reduce_from_tensor_model_parallel_region,
......@@ -41,11 +41,9 @@ from .random import (
init_checkpointed_activations_memory_buffer,
model_parallel_cuda_manual_seed,
reset_checkpointed_activations_memory_buffer,
gather_split_1d_tensor,
split_tensor_into_1d_equal_chunks,
)
from .utils import divide, split_tensor_along_last_dim
from apex.transformer.tensor_parallel.utils import split_tensor_along_last_dim
__all__ = [
......@@ -71,9 +69,6 @@ __all__ = [
"init_checkpointed_activations_memory_buffer",
"model_parallel_cuda_manual_seed",
"reset_checkpointed_activations_memory_buffer",
"gather_split_1d_tensor",
"split_tensor_into_1d_equal_chunks",
# utils.py
"divide",
"split_tensor_along_last_dim",
]
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -14,10 +14,10 @@
# limitations under the License.
import torch
from ..parallel_state import get_tensor_model_parallel_group
from ..parallel_state import get_tensor_model_parallel_rank
from ..parallel_state import get_tensor_model_parallel_world_size
from .utils import VocabUtility
from apex.transformer.parallel_state import get_tensor_model_parallel_group
from apex.transformer.parallel_state import get_tensor_model_parallel_rank
from apex.transformer.parallel_state import get_tensor_model_parallel_world_size
from apex.transformer.tensor_parallel.utils import VocabUtility
class _VocabParallelCrossEntropy(torch.autograd.Function):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment