"docs/source/en/using-diffusers/automodel.md" did not exist on "6a2309b98d415d4ca1da69f59283507fe3eb1d73"
Commit 3c92fa93 authored by Jared Casper

Move pipeline parallel functionality into core with associated changes.

parent 0b44909c
......@@ -17,7 +17,8 @@ from megatron import print_rank_0
from megatron.core import mpu
from megatron.data.blendable_dataset import BlendableDataset
from megatron.data.gpt_dataset import build_train_valid_test_datasets
from megatron.model import GPTModel, ModelType
from megatron.model import GPTModel
from megatron.core.enums import ModelType
from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids
from megatron.utils import average_losses_across_data_parallel_group
......
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import enum
class ModelType(enum.Enum):
encoder_or_decoder = 1
encoder_and_decoder = 2
......@@ -58,12 +58,40 @@ def initialize_model_parallel(
Initialize model data parallel groups.
Arguments:
tensor_model_parallel_size: number of GPUs used for tensor model parallelism.
pipeline_model_parallel_size: number of GPUs used for pipeline model parallelism.
virtual_pipeline_model_parallel_size: number of virtual stages (interleaved
pipeline).
pipeline_model_parallel_split_rank: for models with both encoder and decoder,
rank in pipeline with split point.
tensor_model_parallel_size (int, default = 1):
The number of GPUs to split individual tensors across.
pipeline_model_parallel_size (int, default = 1):
The number of tensor parallel GPU groups to split the
Transformer layers across. For example, if
tensor_model_parallel_size is 4 and
pipeline_model_parallel_size is 2, the model will be split
into 2 groups of 4 GPUs.
virtual_pipeline_model_parallel_size (int, optional):
The number of stages that each pipeline group will have,
interleaving as necessary. If None, no interleaving is
performed. For example, if tensor_model_parallel_size is 1,
pipeline_model_parallel_size is 4,
virtual_pipeline_model_parallel_size is 2, and there are
16 transformer layers in the model, the model will be
split into 8 stages with two layers each and each GPU
would get 2 stages as such (layer number starting with 1):
GPU 0: [1, 2] [9, 10]
GPU 1: [3, 4] [11, 12]
GPU 2: [5, 6] [13, 14]
GPU 3: [7, 8] [15, 16]
pipeline_model_parallel_split_rank (int, optional):
For models with both an encoder and decoder, the rank in
pipeline to switch between encoder and decoder (i.e. the
first rank of the decoder). This allows the user to set
the pipeline parallel size of the encoder and decoder
independently. For example, if
pipeline_model_parallel_size is 8 and
pipeline_model_parallel_split_rank is 3, then ranks 0-2
will be the encoder and ranks 3-7 will be the decoder.
Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
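The layer-to-GPU assignment in the virtual_pipeline_model_parallel_size example above can be reproduced with a short, illustrative Python sketch (not part of this commit); the mapping logic shown here is inferred from the docstring, not taken from the library's internal implementation:
# Illustrative sketch (not from this commit): reproduce the interleaved layer
# assignment described in the docstring for 4 pipeline stages, 2 virtual
# stages, and 16 transformer layers.
pipeline_model_parallel_size = 4
virtual_pipeline_model_parallel_size = 2
num_layers = 16

layers_per_chunk = num_layers // (
    pipeline_model_parallel_size * virtual_pipeline_model_parallel_size)  # 2
for gpu in range(pipeline_model_parallel_size):
    chunks = []
    for chunk in range(virtual_pipeline_model_parallel_size):
        start = (chunk * pipeline_model_parallel_size + gpu) * layers_per_chunk + 1
        chunks.append(list(range(start, start + layers_per_chunk)))
    print(f"GPU {gpu}: {chunks}")
# GPU 0: [[1, 2], [9, 10]]
# GPU 1: [[3, 4], [11, 12]]
# GPU 2: [[5, 6], [13, 14]]
# GPU 3: [[7, 8], [15, 16]]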
......@@ -298,8 +326,8 @@ def set_pipeline_model_parallel_rank(rank):
def set_pipeline_model_parallel_split_rank(rank):
"""Set pipeline model parallel split rank."""
global _MPU_PIPELINE_MODEL_PARALLEL_SPLIT_RANK
_MPU_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = rank
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = rank
def get_tensor_model_parallel_rank():
......@@ -318,6 +346,11 @@ def get_pipeline_model_parallel_rank():
return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
def get_pipeline_model_parallel_split_rank():
"""Return pipeline model parallel split rank."""
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
return _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
def is_pipeline_first_stage(ignore_virtual=False):
"""Return True if in the first pipeline model-parallel stage, False otherwise."""
......
from .schedules import get_forward_backward_func
......@@ -2,15 +2,24 @@
from functools import reduce
import operator
from typing import Optional, List, Union, Callable, Tuple
import torch
from megatron import get_args, core
from megatron.core import mpu
from megatron import core
from megatron.core.parallel_state import (
get_pipeline_model_parallel_group,
get_pipeline_model_parallel_prev_rank,
get_pipeline_model_parallel_next_rank,
)
# Types
Shape = Union[List[int], torch.Size]
def _communicate_shapes(tensor_send_next, tensor_send_prev,
recv_prev, recv_next):
"""Communicate tensor shapes between stages. Used to communicate
recv_prev, recv_next,
use_ring_exchange_p2p):
"""Communicate tensor shapes between stages. Used to communicate
tensor shapes before the actual tensor communication happens.
This is required when the sequence lengths across micro batches
are not uniform.
......@@ -28,7 +37,6 @@ def _communicate_shapes(tensor_send_next, tensor_send_prev,
(recv_prev_shape, recv_next_shape)
"""
args = get_args()
recv_prev_shape_tensor = None
recv_next_shape_tensor = None
send_prev_shape_tensor = None
......@@ -50,7 +58,7 @@ def _communicate_shapes(tensor_send_next, tensor_send_prev,
device=torch.cuda.current_device(),
dtype=torch.int64)
if args.use_ring_exchange_p2p:
if use_ring_exchange_p2p:
torch.distributed.ring_exchange(tensor_send_prev=send_prev_shape_tensor,
tensor_recv_prev=recv_prev_shape_tensor,
tensor_send_next=send_next_shape_tensor,
......@@ -98,46 +106,70 @@ def _communicate_shapes(tensor_send_next, tensor_send_prev,
return recv_prev_shape, recv_next_shape
def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
tensor_shape,
dtype_=None):
def _communicate(*, tensor_send_next: Optional[torch.Tensor],
tensor_send_prev: Optional[torch.Tensor],
recv_prev: bool,
recv_next: bool,
tensor_shape: Shape,
dtype: Optional[torch.dtype],
variable_seq_lengths: bool = False,
use_ring_exchange_p2p: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Communicate tensors between stages. Used as helper method in other
communication methods that are used in megatron/schedules.py.
Takes the following arguments:
tensor_send_next: tensor to send to next rank (no tensor sent if
set to None).
tensor_send_prev: tensor to send to prev rank (no tensor sent if
set to None).
recv_prev: boolean for whether tensor should be received from
previous rank.
recv_next: boolean for whether tensor should be received from
next rank.
tensor_shape: shape of tensor to receive (this method assumes that all
tensors sent and received in a single function call are
the same shape).
dtype_: optional, this is used when the tensor that needs to be
communicated is different from args.params_dtype.
Arguments:
tensor_send_next (torch.Tensor, optional):
Tensor to send to next rank (no tensor sent if None)
tensor_send_prev (torch.Tensor, optional):
Tensor to send to prev rank (no tensor sent if None)
recv_prev (boolean, required):
whether tensor should be received from previous rank.
recv_next (boolean, required):
whether tensor should be received from next rank.
tensor_shape (List[int] or torch.Size, required):
shape of tensor to receive (this method assumes that all
tensors sent and received in a single function call are
the same shape).
dtype (torch.dtype, required if either recv_{prev,next} is True):
this must be the type of the tensors that will be
received, will typically be params_dtype, but in the case
of fp32 residual connections might be torch.float.
variable_seq_lengths (bool, optional, default=False):
Support for variable sequence lengths across
microbatches. Setting this communicates the size of
tensors during pipeline-parallel communication; because
of this extra overhead, it should only be set if the
sequence length is not constant during training.
use_ring_exchange_p2p (bool, optional, default = False):
Use custom ring_exchange kernel instead of
torch.distributed.batch_isend_irecv(). Requires custom
built torch with torch.distributed.ring_exchange.
Returns:
(tensor_recv_prev, tensor_recv_next)
tuple containing
- tensor_recv_prev: torch.Tensor if recv_prev is True, None otherwise.
- tensor_recv_next: torch.Tensor if recv_next is True, None otherwise.
"""
args = get_args()
# Create placeholder tensors for receive in forward and backward directions
# if needed.
tensor_recv_prev = None
tensor_recv_next = None
# Some legacy inference code doesn't set the tensor shape, do so now
# for the normal values for gpt/bert. This could be removed if inference
# code is changed to provide tensor_shape.
if not args.variable_seq_lengths:
if tensor_shape is None:
recv_prev_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
recv_next_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
else:
recv_prev_shape = tensor_shape
recv_next_shape = tensor_shape
if not variable_seq_lengths:
recv_prev_shape = tensor_shape
recv_next_shape = tensor_shape
else:
recv_prev_shape, recv_next_shape = \
_communicate_shapes(tensor_send_next,
......@@ -145,116 +177,81 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
recv_prev,
recv_next)
override_scatter_gather_tensors_in_pipeline = False
if args.scatter_gather_tensors_in_pipeline and \
not args.sequence_parallel:
recv_prev_chunk_shape = reduce(operator.mul, recv_prev_shape, 1)
recv_next_chunk_shape = reduce(operator.mul, recv_next_shape, 1)
if recv_prev_chunk_shape % mpu.get_tensor_model_parallel_world_size() == 0 and \
recv_next_chunk_shape % mpu.get_tensor_model_parallel_world_size() == 0:
recv_prev_chunk_shape = recv_prev_chunk_shape // \
mpu.get_tensor_model_parallel_world_size()
recv_next_chunk_shape = recv_next_chunk_shape // \
mpu.get_tensor_model_parallel_world_size()
else:
recv_prev_chunk_shape = recv_prev_shape
recv_next_chunk_shape = recv_next_shape
override_scatter_gather_tensors_in_pipeline = True
else:
recv_prev_chunk_shape = recv_prev_shape
recv_next_chunk_shape = recv_next_shape
dtype = args.params_dtype
if args.fp32_residual_connection:
dtype = torch.float
requires_grad = True
if dtype_ is not None:
dtype = dtype_
requires_grad = False
if recv_prev:
tensor_recv_prev = torch.empty(recv_prev_chunk_shape,
requires_grad=requires_grad,
if dtype is None:
raise RuntimeError("dtype must be provided if recv_prev is True")
if tensor_shape is None:
raise RuntimeError(
"tensor_shape must be specified if recv_prev is True. "
"Common tensor_shape is (seq_length, micro_batch_size, hidden_size)"
)
tensor_recv_prev = torch.empty(recv_prev_shape,
requires_grad=True,
device=torch.cuda.current_device(),
dtype=dtype)
if recv_next:
tensor_recv_next = torch.empty(recv_next_chunk_shape,
requires_grad=requires_grad,
if dtype is None:
raise RuntimeError("dtype must be provided if recv_next is True")
if tensor_shape is None:
raise RuntimeError(
"tensor_shape must be specified if recv_next is True. "
"Common tensor_shape is (seq_length, micro_batch_size, hidden_size)"
)
tensor_recv_next = torch.empty(recv_next_shape,
requires_grad=True,
device=torch.cuda.current_device(),
dtype=dtype)
# Split tensor into smaller chunks if using scatter-gather optimization.
if not override_scatter_gather_tensors_in_pipeline and \
args.scatter_gather_tensors_in_pipeline and \
not args.sequence_parallel:
if tensor_send_next is not None:
tensor_send_next = core.tensor_parallel.split_tensor_into_1d_equal_chunks(tensor_send_next)
if tensor_send_prev is not None:
tensor_send_prev = core.tensor_parallel.split_tensor_into_1d_equal_chunks(tensor_send_prev)
# Send tensors in both the forward and backward directions as appropriate.
if args.use_ring_exchange_p2p:
if use_ring_exchange_p2p:
torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
tensor_recv_prev=tensor_recv_prev,
tensor_send_next=tensor_send_next,
tensor_recv_next=tensor_recv_next,
group=mpu.get_pipeline_model_parallel_group())
group=get_pipeline_model_parallel_group())
else:
ops = []
if tensor_send_prev is not None:
send_prev_op = torch.distributed.P2POp(
torch.distributed.isend, tensor_send_prev,
mpu.get_pipeline_model_parallel_prev_rank())
get_pipeline_model_parallel_prev_rank())
ops.append(send_prev_op)
if tensor_recv_prev is not None:
recv_prev_op = torch.distributed.P2POp(
torch.distributed.irecv, tensor_recv_prev,
mpu.get_pipeline_model_parallel_prev_rank())
get_pipeline_model_parallel_prev_rank())
ops.append(recv_prev_op)
if tensor_send_next is not None:
send_next_op = torch.distributed.P2POp(
torch.distributed.isend, tensor_send_next,
mpu.get_pipeline_model_parallel_next_rank())
get_pipeline_model_parallel_next_rank())
ops.append(send_next_op)
if tensor_recv_next is not None:
recv_next_op = torch.distributed.P2POp(
torch.distributed.irecv, tensor_recv_next,
mpu.get_pipeline_model_parallel_next_rank())
get_pipeline_model_parallel_next_rank())
ops.append(recv_next_op)
if len(ops) > 0:
reqs = torch.distributed.batch_isend_irecv(ops)
for req in reqs:
req.wait()
# To protect against race condition when using batch_isend_irecv().
# User should assert that we have a modern enough PyTorch to not need this
torch.cuda.synchronize()
# If using scatter-gather optimization, gather smaller chunks.
if not override_scatter_gather_tensors_in_pipeline and \
args.scatter_gather_tensors_in_pipeline and \
not args.sequence_parallel:
if recv_prev:
tensor_recv_prev = core.tensor_parallel.gather_split_1d_tensor(
tensor_recv_prev).view(recv_prev_shape).requires_grad_()
tensor_recv_prev = core.utils.make_viewless_tensor(tensor_recv_prev,
requires_grad=True,
keep_graph=False)
if recv_next:
tensor_recv_next = core.tensor_parallel.gather_split_1d_tensor(
tensor_recv_next).view(recv_next_shape).requires_grad_()
tensor_recv_next = core.utils.make_viewless_tensor(tensor_recv_next,
requires_grad=True,
keep_graph=False)
return tensor_recv_prev, tensor_recv_next
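For readers unfamiliar with the P2POp / batch_isend_irecv pattern that _communicate builds its send/receive operations on, here is a minimal two-rank sketch (not part of this commit). It uses the Gloo backend and CPU tensors for simplicity; Megatron runs this with NCCL and CUDA tensors between adjacent pipeline ranks.
# Minimal sketch (not from this commit): exchange a tensor between two ranks
# with torch.distributed.P2POp + batch_isend_irecv, the same primitive used
# in _communicate above. Launch with: torchrun --nproc_per_node=2 sketch.py
import torch
import torch.distributed as dist

dist.init_process_group(backend="gloo")   # Megatron uses NCCL with CUDA tensors
rank = dist.get_rank()
peer = 1 - rank                           # stands in for the prev/next pipeline rank

send_buf = torch.full((4,), float(rank))
recv_buf = torch.empty(4)
ops = [dist.P2POp(dist.isend, send_buf, peer),
       dist.P2POp(dist.irecv, recv_buf, peer)]
for req in dist.batch_isend_irecv(ops):
    req.wait()
print(f"rank {rank} received {recv_buf.tolist()}")
dist.destroy_process_group()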
def recv_forward(tensor_shape=None, dtype_=None, timers=None):
"""Receive tensor from previous rank in pipeline (forward receive)."""
def recv_forward(tensor_shape: Shape,
dtype: torch.dtype,
timers: Callable = None) -> torch.Tensor:
""" Receive tensor from previous rank in pipeline (forward receive).
if mpu.is_pipeline_first_stage():
See _communicate for argument details.
"""
if core.parallel_state.is_pipeline_first_stage():
input_tensor = None
else:
if timers is not None:
......@@ -265,15 +262,20 @@ def recv_forward(tensor_shape=None, dtype_=None, timers=None):
recv_prev=True,
recv_next=False,
tensor_shape=tensor_shape,
dtype_=dtype_)
dtype=dtype)
if timers is not None:
timers('forward-recv').stop()
return input_tensor
def recv_backward(tensor_shape=None, timers=None):
"""Receive tensor from next rank in pipeline (backward receive)."""
if mpu.is_pipeline_last_stage():
def recv_backward(tensor_shape: Shape,
dtype: torch.dtype,
timers: Callable = None) -> torch.Tensor:
"""Receive tensor from next rank in pipeline (backward receive).
See _communicate for argument details.
"""
if core.parallel_state.is_pipeline_last_stage():
output_tensor_grad = None
else:
if timers is not None:
......@@ -283,16 +285,21 @@ def recv_backward(tensor_shape=None, timers=None):
tensor_send_prev=None,
recv_prev=False,
recv_next=True,
tensor_shape=tensor_shape)
tensor_shape=tensor_shape,
dtype=dtype)
if timers is not None:
timers('backward-recv').stop()
return output_tensor_grad
def send_forward(output_tensor, tensor_shape=None, dtype_=None, timers=None):
"""Send tensor to next rank in pipeline (forward send)."""
def send_forward(output_tensor: torch.Tensor,
timers: Callable = None) -> None:
"""Send tensor to next rank in pipeline (forward send).
See _communicate for argument details.
"""
if not mpu.is_pipeline_last_stage():
if not core.parallel_state.is_pipeline_last_stage():
if timers is not None:
timers('forward-send', log_level=2).start()
_communicate(
......@@ -300,15 +307,19 @@ def send_forward(output_tensor, tensor_shape=None, dtype_=None, timers=None):
tensor_send_prev=None,
recv_prev=False,
recv_next=False,
tensor_shape=tensor_shape,
dtype_=dtype_)
tensor_shape=None,
dtype=None)
if timers is not None:
timers('forward-send').stop()
def send_backward(input_tensor_grad, tensor_shape=None, timers=None):
"""Send tensor to previous rank in pipeline (backward send)."""
if not mpu.is_pipeline_first_stage():
def send_backward(input_tensor_grad: torch.Tensor,
timers: Callable = None) -> None:
"""Send tensor to previous rank in pipeline (backward send).
See _communicate for argument details.
"""
if not core.parallel_state.is_pipeline_first_stage():
if timers is not None:
timers('backward-send', log_level=2).start()
_communicate(
......@@ -316,14 +327,21 @@ def send_backward(input_tensor_grad, tensor_shape=None, timers=None):
tensor_send_prev=input_tensor_grad,
recv_prev=False,
recv_next=False,
tensor_shape=tensor_shape)
tensor_shape=None,
dtype=None)
if timers is not None:
timers('backward-send').stop()
def send_forward_recv_backward(output_tensor, tensor_shape=None, timers=None):
"""Batched send and recv with next rank in pipeline."""
if mpu.is_pipeline_last_stage():
def send_forward_recv_backward(output_tensor: torch.Tensor,
tensor_shape: Shape,
dtype: torch.dtype,
timers: Callable = None) -> torch.Tensor:
"""Batched send and recv with next rank in pipeline.
See _communicate for argument details.
"""
if core.parallel_state.is_pipeline_last_stage():
output_tensor_grad = None
else:
if timers is not None:
......@@ -333,15 +351,22 @@ def send_forward_recv_backward(output_tensor, tensor_shape=None, timers=None):
tensor_send_prev=None,
recv_prev=False,
recv_next=True,
tensor_shape=tensor_shape)
tensor_shape=tensor_shape,
dtype=dtype)
if timers is not None:
timers('forward-send-backward-recv').stop()
return output_tensor_grad
def send_backward_recv_forward(input_tensor_grad, tensor_shape=None, timers=None):
"""Batched send and recv with previous rank in pipeline."""
if mpu.is_pipeline_first_stage():
def send_backward_recv_forward(input_tensor_grad: torch.Tensor,
tensor_shape: Shape,
dtype: torch.dtype,
timers: Callable = None) -> torch.Tensor:
"""Batched send and recv with previous rank in pipeline.
See _communicate for argument details.
"""
if core.parallel_state.is_pipeline_first_stage():
input_tensor = None
else:
if timers is not None:
......@@ -351,14 +376,22 @@ def send_backward_recv_forward(input_tensor_grad, tensor_shape=None, timers=None
tensor_send_prev=input_tensor_grad,
recv_prev=True,
recv_next=False,
tensor_shape=tensor_shape)
tensor_shape=tensor_shape,
dtype=dtype)
if timers is not None:
timers('backward-send-forward-recv').stop()
return input_tensor
def send_forward_recv_forward(output_tensor, recv_prev, tensor_shape=None, timers=None):
"""Batched recv from previous rank and send to next rank in pipeline."""
def send_forward_recv_forward(output_tensor: torch.Tensor,
recv_prev: bool,
tensor_shape: Shape,
dtype: torch.dtype,
timers: Callable = None) -> torch.Tensor:
"""Batched recv from previous rank and send to next rank in pipeline.
See _communicate for argument details.
"""
if timers is not None:
timers('forward-send-forward-recv', log_level=2).start()
input_tensor, _ = _communicate(
......@@ -366,14 +399,22 @@ def send_forward_recv_forward(output_tensor, recv_prev, tensor_shape=None, timer
tensor_send_prev=None,
recv_prev=recv_prev,
recv_next=False,
tensor_shape=tensor_shape)
tensor_shape=tensor_shape,
dtype=dtype)
if timers is not None:
timers('forward-send-forward-recv').stop()
return input_tensor
def send_backward_recv_backward(input_tensor_grad, recv_next, tensor_shape=None, timers=None):
"""Batched recv from next rank and send to previous rank in pipeline."""
def send_backward_recv_backward(input_tensor_grad: torch.Tensor,
recv_next: bool,
tensor_shape: Shape,
dtype: torch.dtype,
timers: Callable = None) -> torch.Tensor:
"""Batched recv from next rank and send to previous rank in pipeline.
See _communicate for argument details.
"""
if timers is not None:
timers('backward-send-backward-recv', log_level=2).start()
_, output_tensor_grad = _communicate(
......@@ -381,16 +422,25 @@ def send_backward_recv_backward(input_tensor_grad, recv_next, tensor_shape=None,
tensor_send_prev=input_tensor_grad,
recv_prev=False,
recv_next=recv_next,
tensor_shape=tensor_shape)
tensor_shape=tensor_shape,
dtype=dtype)
if timers is not None:
timers('backward-send-backward-recv').stop()
return output_tensor_grad
def send_forward_backward_recv_forward_backward(
output_tensor, input_tensor_grad, recv_prev,
recv_next, tensor_shape=None, timers=None):
"""Batched send and recv with previous and next ranks in pipeline."""
output_tensor: torch.Tensor,
input_tensor_grad: torch.Tensor,
recv_prev: bool,
recv_next: bool,
tensor_shape: Shape,
dtype: torch.dtype,
timers: Callable = None) -> Tuple[torch.Tensor, torch.Tensor]:
"""Batched send and recv with previous and next ranks in pipeline.
See _communicate for argument details.
"""
if timers is not None:
timers('forward-backward-send-forward-backward-recv',
log_level=2).start()
......@@ -399,7 +449,8 @@ def send_forward_backward_recv_forward_backward(
tensor_send_prev=input_tensor_grad,
recv_prev=recv_prev,
recv_next=recv_next,
tensor_shape=tensor_shape)
tensor_shape=tensor_shape,
dtype=dtype)
if timers is not None:
timers('forward-backward-send-forward-backward-recv').stop()
return input_tensor, output_tensor_grad
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from typing import Optional, List, Union, Callable, Any
import torch
from torch.autograd.variable import Variable
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron import get_args
from megatron import get_num_microbatches
from megatron import get_timers
from megatron import p2p_communication
from megatron.core import mpu
from megatron.utils import unwrap_model
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
from megatron.model import ModelType
from megatron.core import parallel_state
from megatron.core.pipeline_parallel import p2p_communication
from megatron.core.enums import ModelType
from megatron.core.utils import get_attr_wrapped_model, get_model_type
# Types
Shape = Union[List[int], torch.Size]
def get_forward_backward_func():
args = get_args()
if mpu.get_pipeline_model_parallel_world_size() > 1:
if args.virtual_pipeline_model_parallel_size is not None:
"""Retrieves the appropriate forward_backward function given the
configuration of parallel_state.
Returns a function that will perform all of the forward and
backward passes of the model given the pipeline model parallel
world size and virtual pipeline model parallel world size in the
global parallel_state.
The function returned takes the following arguments:
forward_step_func (required): A function that takes a data
iterator and a model as its arguments and returns the model's
forward output and the loss function. The loss function should
take one torch.Tensor and return a torch.Tensor of loss and a
dictionary of string -> torch.Tensor.
For example:
def loss_func(loss_mask, output_tensor):
losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {'lm loss': averaged_loss[0]}
def forward_step(data_iterator, model):
data, loss_mask = next(data_iterator)
output = model(data)
return output, partial(loss_func, loss_mask)
forward_backward_func(forward_step_func=forward_step, ...)
data_iterator (required): an iterator over the data, will be
passed as is to forward_step_func
model (required): the actual model. A torch.nn.Module or, in the
case of interleaving, a list of torch.nn.Module
num_microbatches (int, required):
The number of microbatches to go through
dtype (required when using pipeline parallelism): dtype used in
p2p communication, usually params_dtype
tensor_shape (required when using pipeline parallelism): Shape of
tensor. The tensor is expected to be 3D and its order of
dimension is supposed to be ``(sequence, batch, hidden)``.
decoder_seq_length (int, required for ModelType.encoder_and_decoder models):
Sequence length of the decoder portion, used to determine tensor shapes.
grad_scaler (optional, default=None): If using loss scaling,
this function should take the loss and return the scaled
loss. If None, no function is called on the loss.
sequence_parallel (optional, default=False):
Set to :obj:`True` for this function to handle sequence
length. When :obj:`True`, the sequence length on each tensor
model parallel rank is updated to
:math:`original\_sequence\_length /
tensor\_model\_parallel\_world\_size`.
TODO: Do we need this? Just roll into tensor_shape arg?
forward_only (optional, default=False): Perform only the forward step
timers (optional, default=None): TODO
collect_non_loss_data: TODO
"""
pipeline_model_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
if pipeline_model_parallel_size > 1:
if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
forward_backward_func = forward_backward_pipelining_with_interleaving
assert get_num_microbatches() % \
args.pipeline_model_parallel_size == 0, \
'number of microbatches (%d) is not divisible by pipeline-' \
'model-parallel-size (%d) when using interleaved schedule' % (
get_num_microbatches(),
args.pipeline_model_parallel_size,
)
else:
forward_backward_func = forward_backward_pipelining_without_interleaving
else:
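The new keyword-only calling convention of the returned function is exercised by train_step later in this commit; here is a condensed sketch of that call site (names such as args, model, optimizer, train_data_iterator, and forward_step_func are assumed to come from the caller's scope):
# Sketch based on the train_step changes later in this commit.
forward_backward_func = get_forward_backward_func()
losses_reduced = forward_backward_func(
    forward_step_func=forward_step_func,
    data_iterator=train_data_iterator,
    model=model,
    num_microbatches=get_num_microbatches(),
    dtype=args.params_dtype,
    tensor_shape=(args.seq_length, args.micro_batch_size, args.hidden_size),
    grad_scaler=optimizer.scale_loss,
    sequence_parallel=args.sequence_parallel,
    forward_only=False,
    timers=None)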
......@@ -52,7 +119,7 @@ def deallocate_output_tensor(out):
device = out.device,
dtype = out.dtype,
)
def custom_backward(output, grad_output):
'''Directly call C++ autograd engine.
......@@ -87,11 +154,15 @@ def custom_backward(output, grad_output):
allow_unreachable=True,
accumulate_grad=True,
)
def forward_step(forward_step_func,
data_iterator,
model,
num_microbatches,
input_tensor,
forward_data_store,
timers,
......@@ -102,25 +173,26 @@ def forward_step(forward_step_func,
passed-in input_tensor is used.
Returns output tensor."""
args = get_args()
if timers is not None:
timers('forward-compute', log_level=2).start()
unwrapped_model = unwrap_model(
model, (torchDDP, LocalDDP, Float16Module))
unwrap_output_tensor = False
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
unwrap_output_tensor = True
unwrapped_model.set_input_tensor(input_tensor)
output_tensor, loss_func = forward_step_func(data_iterator, model)
if mpu.is_pipeline_last_stage():
set_input_tensor = get_attr_wrapped_model(model, "set_input_tensor")
set_input_tensor(input_tensor)
context_manager = torch.autocast("cuda") if torch.is_autocast_enabled() else nullcontext()
with context_manager:
output_tensor, loss_func = forward_step_func(data_iterator, model)
if parallel_state.is_pipeline_last_stage():
if not collect_non_loss_data:
output_tensor = loss_func(output_tensor)
loss, loss_reduced = output_tensor
output_tensor = loss / get_num_microbatches()
output_tensor = loss / num_microbatches
forward_data_store.append(loss_reduced)
else:
data = loss_func(output_tensor, non_loss_data=True)
......@@ -132,16 +204,18 @@ def forward_step(forward_step_func,
# If T5 model (or other model with encoder and decoder)
# and in decoder stack, then send encoder_hidden_state
# downstream as well.
if mpu.is_pipeline_stage_after_split() and \
args.model_type == ModelType.encoder_and_decoder:
model_type = get_model_type(model)
if parallel_state.is_pipeline_stage_after_split() and \
model_type == ModelType.encoder_and_decoder:
return [output_tensor, input_tensor[-1]]
if unwrap_output_tensor:
return output_tensor
return [output_tensor]
def backward_step(optimizer, input_tensor, output_tensor,
output_tensor_grad, timers):
def backward_step(grad_scaler, input_tensor, output_tensor,
output_tensor_grad, model_type, timers):
"""Backward step through passed-in output tensor.
If last stage, output_tensor_grad is None, otherwise gradient of loss
......@@ -153,7 +227,6 @@ def backward_step(optimizer, input_tensor, output_tensor,
# NOTE: This code currently can handle at most one skip connection. It
# needs to be modified slightly to support arbitrary numbers of skip
# connections.
args = get_args()
if timers is not None:
timers('backward-compute', log_level=2).start()
......@@ -173,8 +246,8 @@ def backward_step(optimizer, input_tensor, output_tensor,
output_tensor_grad = [output_tensor_grad]
# Backward pass.
if output_tensor_grad[0] is None:
output_tensor = optimizer.scale_loss(output_tensor[0])
if output_tensor_grad[0] is None and grad_scaler is not None:
output_tensor = grad_scaler(output_tensor[0])
custom_backward(output_tensor[0], output_tensor_grad[0])
# Collect the grad of the input_tensor.
......@@ -189,9 +262,9 @@ def backward_step(optimizer, input_tensor, output_tensor,
# Handle single skip connection if it exists (encoder_hidden_state in
# model with encoder and decoder).
if mpu.get_pipeline_model_parallel_world_size() > 1 and \
mpu.is_pipeline_stage_after_split() and \
args.model_type == ModelType.encoder_and_decoder:
if parallel_state.get_pipeline_model_parallel_world_size() > 1 and \
parallel_state.is_pipeline_stage_after_split() and \
model_type == ModelType.encoder_and_decoder:
if output_tensor_grad[1] is not None:
input_tensor_grad[-1].add_(output_tensor_grad[1])
if unwrap_input_tensor_grad:
......@@ -211,16 +284,27 @@ def dummy_handler():
pass
def forward_backward_no_pipelining(forward_step_func,
data_iterator, model,
optimizer,
timers,
forward_only,
collect_non_loss_data=False):
def forward_backward_no_pipelining(*,
forward_step_func,
data_iterator,
model: Union[torch.nn.Module, List[torch.nn.Module]],
num_microbatches: int,
dtype: Optional[torch.dtype] = None, # unused
tensor_shape: Optional[Shape] = None, # unused
decoder_seq_length: Optional[int] = None, # unused
grad_scaler: Callable = None,
sequence_parallel: bool = False, # unused
forward_only: bool = False,
timers: Callable = None,
collect_non_loss_data: bool = False):
"""Run forward and backward passes with no pipeline parallelism
(no inter-stage communication).
Returns dictionary with losses."""
Returns dictionary with losses.
See get_forward_backward_func() for argument details
"""
assert len(model) == 1
model = model[0]
......@@ -228,64 +312,86 @@ def forward_backward_no_pipelining(forward_step_func,
if isinstance(model, torchDDP):
context_handler = model.no_sync
model_type = get_model_type(model)
forward_data_store = []
input_tensor, output_tensor_grad = None, None
with context_handler():
for i in range(get_num_microbatches() - 1):
for i in range(num_microbatches - 1):
output_tensor = forward_step(forward_step_func, data_iterator,
model, input_tensor, forward_data_store,
model, num_microbatches, input_tensor, forward_data_store,
timers, collect_non_loss_data)
if not forward_only:
backward_step(optimizer, input_tensor, output_tensor,
output_tensor_grad, timers)
backward_step(grad_scaler, input_tensor, output_tensor,
output_tensor_grad, model_type, timers)
# Run computation for last microbatch out of context handler (want to
# synchronize gradients).
output_tensor = forward_step(forward_step_func, data_iterator,
model, input_tensor, forward_data_store,
model, num_microbatches, input_tensor, forward_data_store,
timers, collect_non_loss_data)
if not forward_only:
backward_step(optimizer, input_tensor, output_tensor,
output_tensor_grad, timers)
backward_step(grad_scaler, input_tensor, output_tensor,
output_tensor_grad, model_type, timers)
return forward_data_store
def forward_backward_pipelining_with_interleaving(forward_step_func,
data_iterator, model,
optimizer,
timers,
forward_only,
collect_non_loss_data=False):
def forward_backward_pipelining_with_interleaving(*,
forward_step_func,
data_iterator,
model: Union[torch.nn.Module, List[torch.nn.Module]],
num_microbatches: int,
dtype: torch.dtype,
tensor_shape: Shape,
decoder_seq_length: Optional[int] = None,
grad_scaler: Callable = None,
sequence_parallel: bool = False,
forward_only: bool = False,
timers: Callable = None,
collect_non_loss_data: bool = False):
"""Run interleaved 1F1B schedule (model split into model chunks), with
communication between pipeline stages as needed.
Returns dictionary with losses if the last stage, empty dict otherwise."""
args = get_args()
input_tensors = [[] for _ in range(len(model))]
output_tensors = [[] for _ in range(len(model))]
forward_data_store = []
if not forward_only:
output_tensor_grads = [[] for _ in range(len(model))]
pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size()
pipeline_parallel_rank = mpu.get_pipeline_model_parallel_rank()
pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
pipeline_parallel_rank = parallel_state.get_pipeline_model_parallel_rank()
if num_microbatches % pipeline_parallel_size != 0:
msg = f'number of microbatches ({num_microbatches}) is not divisible by '
msg += f'pipeline-model-parallel-size ({pipeline_parallel_size}) '
msg += 'when using interleaved schedule'
raise RuntimeError(msg)
model_type = get_model_type(model[0])
if model_type == ModelType.encoder_and_decoder:
raise RuntimeError("Interleaving is not supported with an encoder and decoder model.")
if decoder_seq_length is not None and decoder_seq_length != tensor_shape[0]:
raise RuntimeError("Interleaving is not supported with a different decoder sequence length.")
if sequence_parallel:
seq_length, batch_size, hidden = tensor_shape
tensor_shape = (
seq_length // parallel_state.get_tensor_model_parallel_world_size(),
batch_size,
hidden,
)
if args.sequence_parallel:
seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
else:
seq_length = args.seq_length
tensor_shape = (seq_length, args.micro_batch_size, args.hidden_size)
# Compute number of warmup and remaining microbatches.
num_model_chunks = len(model)
num_microbatches = get_num_microbatches() * num_model_chunks
total_num_microbatches = num_microbatches * num_model_chunks
all_warmup_microbatches = False
if forward_only:
num_warmup_microbatches = num_microbatches
num_warmup_microbatches = total_num_microbatches
else:
# Run all forward passes and then all backward passes if number of
# microbatches is just the number of pipeline stages.
......@@ -293,8 +399,8 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
# all workers, followed by more microbatches after depending on
# stage ID (more forward passes for earlier stages, later stages can
# immediately start with 1F1B).
if get_num_microbatches() == pipeline_parallel_size:
num_warmup_microbatches = num_microbatches
if num_microbatches == pipeline_parallel_size:
num_warmup_microbatches = total_num_microbatches
all_warmup_microbatches = True
else:
num_warmup_microbatches = \
......@@ -302,9 +408,9 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
num_warmup_microbatches += (
num_model_chunks - 1) * pipeline_parallel_size
num_warmup_microbatches = min(num_warmup_microbatches,
num_microbatches)
total_num_microbatches)
num_microbatches_remaining = \
num_microbatches - num_warmup_microbatches
total_num_microbatches - num_warmup_microbatches
def get_model_chunk_id(microbatch_id, forward):
"""Helper method to get the model chunk ID given the iteration number."""
......@@ -319,10 +425,10 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
(run set_virtual_pipeline_model_parallel_rank() before calling
forward_step())."""
model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
# forward step
if mpu.is_pipeline_first_stage():
if parallel_state.is_pipeline_first_stage():
if len(input_tensors[model_chunk_id]) == \
len(output_tensors[model_chunk_id]):
input_tensors[model_chunk_id].append(None)
......@@ -330,7 +436,8 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
output_tensor = forward_step(forward_step_func,
data_iterator[model_chunk_id],
model[model_chunk_id],
input_tensor,
num_microbatches,
input_tensor,
forward_data_store,
timers,
collect_non_loss_data)
......@@ -348,41 +455,42 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
(run set_virtual_pipeline_model_parallel_rank() before calling
backward_step())."""
model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
if mpu.is_pipeline_last_stage():
if parallel_state.is_pipeline_last_stage():
if len(output_tensor_grads[model_chunk_id]) == 0:
output_tensor_grads[model_chunk_id].append(None)
input_tensor = input_tensors[model_chunk_id].pop(0)
output_tensor = output_tensors[model_chunk_id].pop(0)
output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
input_tensor_grad = \
backward_step(optimizer,
backward_step(grad_scaler,
input_tensor,
output_tensor,
output_tensor_grad,
model_type,
timers)
return input_tensor_grad
# Run warmup forward passes.
mpu.set_virtual_pipeline_model_parallel_rank(0)
parallel_state.set_virtual_pipeline_model_parallel_rank(0)
input_tensors[0].append(
p2p_communication.recv_forward(tensor_shape, timers=timers))
p2p_communication.recv_forward(tensor_shape, dtype, timers=timers))
for k in range(num_warmup_microbatches):
output_tensor = forward_step_helper(k)
# Determine if tensor should be received from previous stage.
next_forward_model_chunk_id = get_model_chunk_id(k+1, forward=True)
recv_prev = True
if mpu.is_pipeline_first_stage(ignore_virtual=True):
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
if next_forward_model_chunk_id == 0:
recv_prev = False
if k == (num_microbatches - 1):
if k == (total_num_microbatches - 1):
recv_prev = False
# Don't send tensor downstream if on last stage.
if mpu.is_pipeline_last_stage():
if parallel_state.is_pipeline_last_stage():
output_tensor = None
# Send and receive tensors as appropriate (send tensors computed
......@@ -391,20 +499,20 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
not all_warmup_microbatches:
input_tensor_grad = None
recv_next = True
if mpu.is_pipeline_last_stage(ignore_virtual=True):
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
recv_next = False
input_tensor, output_tensor_grad = \
p2p_communication.send_forward_backward_recv_forward_backward(
output_tensor, input_tensor_grad,
recv_prev=recv_prev, recv_next=recv_next,
tensor_shape=tensor_shape,
tensor_shape=tensor_shape, dtype=dtype,
timers=timers)
output_tensor_grads[num_model_chunks-1].append(output_tensor_grad)
else:
input_tensor = \
p2p_communication.send_forward_recv_forward(
output_tensor, recv_prev=recv_prev,
tensor_shape=tensor_shape,
tensor_shape=tensor_shape, dtype=dtype,
timers=timers)
input_tensors[next_forward_model_chunk_id].append(input_tensor)
deallocate_output_tensor(output_tensor)
......@@ -425,19 +533,19 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
# Determine if current stage has anything to send in either direction,
# otherwise set tensor to None.
forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
mpu.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
if mpu.is_pipeline_last_stage():
parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
if parallel_state.is_pipeline_last_stage():
output_tensor = None
backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
mpu.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
if mpu.is_pipeline_first_stage():
parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
if parallel_state.is_pipeline_first_stage():
input_tensor_grad = None
# Determine if peers are sending, and where in data structure to put
# received tensors.
recv_prev = True
if mpu.is_pipeline_first_stage(ignore_virtual=True):
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
# First stage is ahead of last stage by (pipeline_parallel_size - 1).
next_forward_model_chunk_id = get_model_chunk_id(
forward_k - (pipeline_parallel_size - 1), forward=True)
......@@ -449,7 +557,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
forward=True)
recv_next = True
if mpu.is_pipeline_last_stage(ignore_virtual=True):
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
# Last stage is ahead of first stage by (pipeline_parallel_size - 1).
next_backward_model_chunk_id = get_model_chunk_id(
backward_k - (pipeline_parallel_size - 1), forward=False)
......@@ -470,7 +578,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
p2p_communication.send_forward_backward_recv_forward_backward(
output_tensor, input_tensor_grad,
recv_prev=recv_prev, recv_next=recv_next,
tensor_shape=tensor_shape, timers=timers)
tensor_shape=tensor_shape, dtype=dtype, timers=timers)
deallocate_output_tensor(output_tensor)
# Put input_tensor and output_tensor_grad in data structures in the
......@@ -486,25 +594,29 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
if all_warmup_microbatches:
output_tensor_grads[num_model_chunks-1].append(
p2p_communication.recv_backward(tensor_shape, timers=timers))
for k in range(num_microbatches_remaining, num_microbatches):
for k in range(num_microbatches_remaining, total_num_microbatches):
input_tensor_grad = backward_step_helper(k)
next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
recv_next = True
if mpu.is_pipeline_last_stage(ignore_virtual=True):
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
if next_backward_model_chunk_id == (num_model_chunks - 1):
recv_next = False
if k == (num_microbatches - 1):
if k == (total_num_microbatches - 1):
recv_next = False
output_tensor_grads[next_backward_model_chunk_id].append(
p2p_communication.send_backward_recv_backward(
input_tensor_grad, recv_next=recv_next,
tensor_shape=tensor_shape,
tensor_shape=tensor_shape, dtype=dtype,
timers=timers))
return forward_data_store
def get_tensor_shapes(rank, model_type):
def get_tensor_shapes(*,
rank: int,
model_type: ModelType,
tensor_shape: Shape,
decoder_seq_length: int,
sequence_parallel: bool):
# Determine right tensor sizes (based on position of rank with respect to split
# rank) and model size.
# Send two tensors if model is T5 and rank is in decoder stage:
......@@ -513,48 +625,50 @@ def get_tensor_shapes(rank, model_type):
# If model is T5 and rank is at the boundary:
# send one tensor (post-transpose from encoder).
# Otherwise, send one tensor (pre-transpose).
args = get_args()
tensor_shapes = []
if args.sequence_parallel:
seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
else:
seq_length = args.seq_length
assert (
len(tensor_shape) == 3
), f"`tensor_shape` should be [sequence_length, micro_batch_size, hidden_size] but {tensor_shape}"
seq_length, micro_batch_size, hidden_size = tensor_shape
if sequence_parallel:
seq_length = seq_length // parallel_state.get_tensor_model_parallel_world_size()
if model_type == ModelType.encoder_and_decoder:
if args.sequence_parallel:
decoder_seq_length = args.decoder_seq_length // mpu.get_tensor_model_parallel_world_size()
else:
decoder_seq_length = args.decoder_seq_length
if sequence_parallel:
decoder_seq_length = decoder_seq_length // parallel_state.get_tensor_model_parallel_world_size()
if mpu.is_pipeline_stage_before_split(rank):
tensor_shapes.append((seq_length, args.micro_batch_size, args.hidden_size))
if parallel_state.is_pipeline_stage_before_split(rank):
tensor_shapes.append((seq_length, micro_batch_size, hidden_size))
else:
tensor_shapes.append((decoder_seq_length, args.micro_batch_size, args.hidden_size))
tensor_shapes.append((seq_length, args.micro_batch_size, args.hidden_size))
tensor_shapes.append((decoder_seq_length, micro_batch_size, hidden_size))
tensor_shapes.append((seq_length, micro_batch_size, hidden_size))
else:
tensor_shapes.append((seq_length, args.micro_batch_size, args.hidden_size))
tensor_shapes.append((seq_length, micro_batch_size, hidden_size))
return tensor_shapes
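As a worked example (not part of this commit, with illustrative sizes), here is what get_tensor_shapes produces for an encoder-and-decoder (T5-style) model with tensor_shape=(512, 4, 1024), decoder_seq_length=128, and sequence parallelism disabled:
# Worked example (not from this commit) of the shapes returned above.
tensor_shape = (512, 4, 1024)      # (seq_length, micro_batch_size, hidden_size)
decoder_seq_length = 128
seq_length, micro_batch_size, hidden_size = tensor_shape

# Rank before the encoder/decoder split: a single encoder-sized activation.
shapes_before_split = [(seq_length, micro_batch_size, hidden_size)]
# Rank at or after the split: decoder activation plus the encoder hidden state.
shapes_after_split = [(decoder_seq_length, micro_batch_size, hidden_size),
                      (seq_length, micro_batch_size, hidden_size)]
print(shapes_before_split)   # [(512, 4, 1024)]
print(shapes_after_split)    # [(128, 4, 1024), (512, 4, 1024)]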
def recv_forward(tensor_shapes, timers):
def recv_forward(tensor_shapes, dtype, timers):
input_tensors = []
for tensor_shape in tensor_shapes:
if tensor_shape is None:
input_tensors.append(None)
else:
input_tensors.append(p2p_communication.recv_forward(tensor_shape,
input_tensors.append(p2p_communication.recv_forward(tensor_shape, dtype,
timers=timers))
return input_tensors
def recv_backward(tensor_shapes, timers):
def recv_backward(tensor_shapes, dtype, timers):
output_tensor_grads = []
for tensor_shape in tensor_shapes:
if tensor_shape is None:
output_tensor_grads.append(None)
else:
output_tensor_grads.append(p2p_communication.recv_backward(tensor_shape,
output_tensor_grads.append(p2p_communication.recv_backward(tensor_shape, dtype,
timers=timers))
return output_tensor_grads
......@@ -565,7 +679,7 @@ def send_forward(output_tensors, tensor_shapes, timers):
for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes):
if tensor_shape is None:
continue
p2p_communication.send_forward(output_tensor, tensor_shape, timers=timers)
p2p_communication.send_forward(output_tensor, timers=timers)
def send_backward(input_tensor_grads, tensor_shapes, timers):
......@@ -574,10 +688,10 @@ def send_backward(input_tensor_grads, tensor_shapes, timers):
for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes):
if tensor_shape is None:
continue
p2p_communication.send_backward(input_tensor_grad, tensor_shape, timers=timers)
p2p_communication.send_backward(input_tensor_grad, timers=timers)
def send_forward_recv_backward(output_tensors, tensor_shapes, timers):
def send_forward_recv_backward(output_tensors, tensor_shapes, dtype, timers):
if not isinstance(output_tensors, list):
output_tensors = [output_tensors]
output_tensor_grads = []
......@@ -586,12 +700,12 @@ def send_forward_recv_backward(output_tensors, tensor_shapes, timers):
output_tensor_grads.append(None)
continue
output_tensor_grad = p2p_communication.send_forward_recv_backward(
output_tensor, tensor_shape, timers=timers)
output_tensor, tensor_shape, dtype, timers=timers)
output_tensor_grads.append(output_tensor_grad)
return output_tensor_grads
def send_backward_recv_forward(input_tensor_grads, tensor_shapes, timers):
def send_backward_recv_forward(input_tensor_grads, tensor_shapes, dtype, timers):
if not isinstance(input_tensor_grads, list):
input_tensor_grads = [input_tensor_grads]
input_tensors = []
......@@ -600,44 +714,55 @@ def send_backward_recv_forward(input_tensor_grads, tensor_shapes, timers):
input_tensors.append(None)
continue
input_tensor = p2p_communication.send_backward_recv_forward(
input_tensor_grad, tensor_shape, timers=timers)
input_tensor_grad, tensor_shape, dtype, timers=timers)
input_tensors.append(input_tensor)
return input_tensors
def forward_backward_pipelining_without_interleaving(forward_step_func,
def forward_backward_pipelining_without_interleaving(*,
forward_step_func,
data_iterator,
model,
optimizer,
timers,
forward_only,
collect_non_loss_data=False):
model: Union[torch.nn.Module, List[torch.nn.Module]],
num_microbatches: int,
dtype: torch.dtype,
tensor_shape: Shape,
decoder_seq_length: Optional[int] = None,
grad_scaler: Callable = None,
sequence_parallel: bool = False,
forward_only: bool = False,
timers: Callable = None,
collect_non_loss_data: bool = False):
"""Run non-interleaved 1F1B schedule, with communication between pipeline
stages.
Returns dictionary with losses if the last stage, empty dict otherwise."""
args = get_args()
assert len(model) == 1
model = model[0]
# Compute number of warmup microbatches.
num_microbatches = get_num_microbatches()
num_warmup_microbatches = \
(mpu.get_pipeline_model_parallel_world_size() -
mpu.get_pipeline_model_parallel_rank() - 1)
(parallel_state.get_pipeline_model_parallel_world_size() -
parallel_state.get_pipeline_model_parallel_rank() - 1)
num_warmup_microbatches = min(
num_warmup_microbatches,
num_microbatches)
num_microbatches_remaining = \
num_microbatches - num_warmup_microbatches
unwrapped_model = unwrap_model(
model, (torchDDP, LocalDDP, Float16Module))
model_type = unwrapped_model.model_type
rank = mpu.get_pipeline_model_parallel_rank()
recv_tensor_shapes = get_tensor_shapes(rank-1, model_type)
send_tensor_shapes = get_tensor_shapes(rank, model_type)
model_type = get_model_type(model)
rank = parallel_state.get_pipeline_model_parallel_rank()
recv_tensor_shapes = get_tensor_shapes(rank=rank-1,
model_type=model_type,
tensor_shape=tensor_shape,
decoder_seq_length=decoder_seq_length,
sequence_parallel=sequence_parallel)
send_tensor_shapes = get_tensor_shapes(rank=rank,
model_type=model_type,
tensor_shape=tensor_shape,
decoder_seq_length=decoder_seq_length,
sequence_parallel=sequence_parallel)
# Input, output tensors only need to be saved when doing backward passes
input_tensors = None
......@@ -649,10 +774,8 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
# Run warmup forward passes.
for i in range(num_warmup_microbatches):
input_tensor = recv_forward(recv_tensor_shapes, timers=timers)
output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, forward_data_store,
timers, collect_non_loss_data)
input_tensor = recv_forward(recv_tensor_shapes, dtype, timers=timers)
output_tensor = forward_step(forward_step_func, data_iterator, model, num_microbatches, input_tensor, forward_data_store, timers, collect_non_loss_data)
send_forward(output_tensor, send_tensor_shapes, timers=timers)
if not forward_only:
......@@ -664,25 +787,26 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
# If all microbatches are run in warmup / cooldown phase, then no need to
# receive this tensor here.
if num_microbatches_remaining > 0:
input_tensor = recv_forward(recv_tensor_shapes, timers=timers)
input_tensor = recv_forward(recv_tensor_shapes, dtype, timers=timers)
# Run 1F1B in steady state.
for i in range(num_microbatches_remaining):
last_iteration = (i == (num_microbatches_remaining - 1))
output_tensor = forward_step(forward_step_func, data_iterator, model,
output_tensor = forward_step(forward_step_func, data_iterator, model, num_microbatches,
input_tensor, forward_data_store,
timers, collect_non_loss_data)
if forward_only:
send_forward(output_tensor, send_tensor_shapes, timers=timers)
if not last_iteration:
input_tensor = recv_forward(recv_tensor_shapes, timers=timers)
input_tensor = recv_forward(recv_tensor_shapes, dtype, timers=timers)
else:
output_tensor_grad = \
send_forward_recv_backward(output_tensor,
send_tensor_shapes,
send_tensor_shapes, dtype,
timers=timers)
# Add input_tensor and output_tensor to end of list.
......@@ -696,8 +820,8 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
output_tensor = output_tensors.pop(0)
input_tensor_grad = \
backward_step(optimizer, input_tensor, output_tensor,
output_tensor_grad, timers)
backward_step(grad_scaler, input_tensor, output_tensor,
output_tensor_grad, model_type, timers)
if last_iteration:
input_tensor = None
......@@ -705,7 +829,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
else:
input_tensor = \
send_backward_recv_forward(
input_tensor_grad, recv_tensor_shapes, timers=timers)
input_tensor_grad, recv_tensor_shapes, dtype, timers=timers)
# Run cooldown backward passes.
if not forward_only:
......@@ -713,11 +837,11 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
output_tensor_grad = recv_backward(send_tensor_shapes, timers=timers)
output_tensor_grad = recv_backward(send_tensor_shapes, dtype, timers=timers)
input_tensor_grad = \
backward_step(optimizer, input_tensor, output_tensor,
output_tensor_grad, timers)
backward_step(grad_scaler, input_tensor, output_tensor,
output_tensor_grad, model_type, timers)
send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers)
......
......@@ -13,6 +13,8 @@ import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.parameter import Parameter
from torch.cuda.amp import custom_fwd, custom_bwd
from megatron.core.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
......@@ -214,6 +216,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
"""See linear_with_grad_accumulation_and_async_allreduce"""
@staticmethod
@custom_fwd
def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
async_grad_allreduce, sequence_parallel):
ctx.save_for_backward(input, weight)
......@@ -243,6 +246,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
return output
@staticmethod
@custom_bwd
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
use_bias = ctx.use_bias
......@@ -407,8 +411,8 @@ def linear_with_grad_accumulation_and_async_allreduce(
"maximum speedup")
linear_with_grad_accumulation_and_async_allreduce.warned = True
with torch.cuda.amp.autocast(enabled=False):
return LinearWithGradAccumulationAndAsyncCommunication.apply(*args)
return LinearWithGradAccumulationAndAsyncCommunication.apply(*args)
linear_with_grad_accumulation_and_async_allreduce.warned = False
class ColumnParallelLinear(torch.nn.Module):
......
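The @custom_fwd/@custom_bwd decorators added above let the autograd Function cooperate with torch.cuda.amp autocast, which is why the explicit torch.cuda.amp.autocast(enabled=False) wrapper could be dropped. A minimal sketch of the pattern on a toy Function (not part of this commit):
# Toy sketch (not from this commit) of the custom_fwd/custom_bwd pattern.
import torch
from torch.cuda.amp import custom_fwd, custom_bwd

class ToyLinear(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, input, weight):
        ctx.save_for_backward(input, weight)
        return torch.matmul(input, weight.t())

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        grad_input = torch.matmul(grad_output, weight)
        grad_weight = torch.matmul(grad_output.t(), input)
        return grad_input, grad_weight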
......@@ -20,6 +20,21 @@ def divide(numerator, denominator):
ensure_divisibility(numerator, denominator)
return numerator // denominator
def get_attr_wrapped_model(model, attr):
"""Get an attribute from a wrapped model"""
if isinstance(model, list):
raise RuntimeError("_get_attr_wrapped_model given a list of models")
while not hasattr(model, attr):
if not hasattr(model, "module"):
raise RuntimeError(f"_get_attr_wrapped_model couldn't find attribute {attr}")
model = model.module
return getattr(model, attr)
def get_model_type(model):
return get_attr_wrapped_model(model, 'model_type')
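A small usage sketch of the new helpers (not part of this commit); ToyWrapper is a hypothetical stand-in for wrappers such as DistributedDataParallel or Float16Module, which expose the wrapped model as .module:
# Sketch (not from this commit): unwrapping nested model wrappers with the
# helpers defined above.
import torch
from megatron.core.enums import ModelType
from megatron.core.utils import get_attr_wrapped_model, get_model_type

class ToyGPT(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model_type = ModelType.encoder_or_decoder

class ToyWrapper(torch.nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

wrapped = ToyWrapper(ToyWrapper(ToyGPT()))
assert get_model_type(wrapped) == ModelType.encoder_or_decoder
assert get_attr_wrapped_model(wrapped, "model_type") == ModelType.encoder_or_decoder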
class GlobalMemoryBuffer:
"""Global buffer to avoid dynamic memory allocations.
......
......@@ -8,4 +8,3 @@ from .gpt_model import GPTModel
from .t5_model import T5Model
from .language_model import get_language_model
from .module import Float16Module
from .enums import ModelType
......@@ -2,10 +2,6 @@
import enum
class ModelType(enum.Enum):
encoder_or_decoder = 1
encoder_and_decoder = 2
class LayerType(enum.Enum):
encoder = 1
decoder = 2
......
......@@ -20,7 +20,8 @@ from megatron import get_args, get_retro_args, get_tensorboard_writer
from megatron.core import parallel_state
from megatron.core import tensor_parallel
from megatron.core import utils as core_utils
from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
from megatron.core.enums import ModelType
from megatron.model.enums import AttnMaskType, LayerType, AttnType
from megatron.model import LayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
......
......@@ -9,7 +9,8 @@ import torch.nn.functional as F
from megatron import get_timers, get_args, core, get_num_microbatches
from .module import MegatronModule
from megatron.core import mpu, tensor_parallel
from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
from megatron.core.enums import ModelType
from megatron.model.enums import AttnMaskType, LayerType, AttnType
from megatron.model import LayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
......
......@@ -25,8 +25,8 @@ from megatron import print_rank_last
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
from megatron.model import Float16Module
from megatron.model import ModelType
from megatron.model import GPTModel
from megatron.core.enums import ModelType
from megatron.optimizer import get_megatron_optimizer
from megatron.initialize import initialize_megatron
from megatron.initialize import write_args_to_tensorboard
......@@ -37,7 +37,7 @@ from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import unwrap_model
from megatron.data.data_samplers import build_pretraining_data_loader
from megatron.utils import calc_params_l2_norm
from megatron.schedules import get_forward_backward_func
from megatron.core.pipeline_parallel import get_forward_backward_func
from megatron.utils import report_memory
from megatron.model.vision.knn_monitor import compute_feature_bank
......@@ -400,6 +400,7 @@ def setup_model_and_optimizer(model_provider_func,
return model, optimizer, opt_param_scheduler
def train_step(forward_step_func, data_iterator,
model, optimizer, opt_param_scheduler):
"""Single training step."""
......@@ -418,8 +419,16 @@ def train_step(forward_step_func, data_iterator,
forward_backward_func = get_forward_backward_func()
fwd_bwd_timers = timers if args.timing_log_level > 1 else None
losses_reduced = forward_backward_func(
forward_step_func, data_iterator, model,
optimizer, fwd_bwd_timers, forward_only=False)
forward_step_func=forward_step_func,
data_iterator=data_iterator,
model=model,
num_microbatches=get_num_microbatches(),
dtype=args.params_dtype,
tensor_shape=(args.seq_length, args.micro_batch_size, args.hidden_size),
grad_scaler=optimizer.scale_loss,
sequence_parallel=args.sequence_parallel,
forward_only=False,
timers=fwd_bwd_timers)
timers('forward-backward').stop()
# Empty unused memory.
......@@ -794,8 +803,15 @@ def evaluate(forward_step_func,
forward_backward_func = get_forward_backward_func()
loss_dicts = forward_backward_func(
forward_step_func, data_iterator, model, optimizer=None,
timers=None, forward_only=True)
forward_step_func=forward_step_func,
data_iterator=data_iterator,
model=model,
num_microbatches=get_num_microbatches(),
dtype=args.params_dtype,
tensor_shape=(args.seq_length, args.micro_batch_size, args.hidden_size),
sequence_parallel=args.sequence_parallel,
forward_only=True,
timers=None)
# Empty unused memory
if args.empty_unused_memory_level >= 1:
......
......@@ -11,8 +11,9 @@ from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron.core import tensor_parallel
from megatron.core.enums import ModelType
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import BertModel, ModelType
from megatron.model import BertModel
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
......
......@@ -9,8 +9,9 @@ from megatron import print_rank_0
from megatron import get_timers
from megatron import get_tokenizer
from megatron.core import tensor_parallel
from megatron.core.enums import ModelType
from megatron.data.gpt_dataset import build_train_valid_test_datasets
from megatron.model import GPTModel, ModelType
from megatron.model import GPTModel
from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids
from megatron.utils import average_losses_across_data_parallel_group
......
......@@ -13,9 +13,9 @@ from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron.core import mpu
from megatron.core.enums import ModelType
from megatron.data.biencoder_dataset_utils import get_ict_batch
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import ModelType
from megatron.model.biencoder_model import biencoder_model_provider
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
......
......@@ -10,7 +10,8 @@ from megatron import get_timers
from megatron import get_tokenizer
from megatron import print_rank_0
from megatron.core import mpu, tensor_parallel
from megatron.model import GPTModel, ModelType
from megatron.core.enums import ModelType
from megatron.model import GPTModel
from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids
from tools.retro.pretraining.retro_dataset import get_retro_datasets
......
......@@ -12,8 +12,9 @@ from megatron import (
print_rank_0
)
from megatron.core import tensor_parallel
from megatron.core.enums import ModelType
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import T5Model, ModelType
from megatron.model import T5Model
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
......
......@@ -6,8 +6,8 @@ import torch
import torch.nn.functional as F
from functools import partial
from megatron import get_args, get_timers, print_rank_0
from megatron.core.enums import ModelType
from megatron.data.vit_dataset import build_train_valid_datasets
from megatron.model import ModelType
from megatron.model.vision.classification import VitClassificationModel
from megatron.model.vision.classification import MitClassificationModel
from megatron.training import pretrain
......
......@@ -7,6 +7,7 @@ import numpy as np
import torch.distributed as dist
from functools import partial
from megatron import get_args, get_timers, print_rank_0
from megatron.core.enums import ModelType
from megatron.data.vit_dataset import build_train_valid_datasets
from megatron.model.vision.dino import DINOPretrainModel
from megatron.model.vision.knn_monitor import knn_predict, get_feature_bank
......@@ -15,7 +16,6 @@ from megatron.utils import average_losses_across_data_parallel_group, unwrap_mod
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
from megatron.model import ModelType
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
......