Commit 3c92fa93 authored by Jared Casper

Move pipeline parallel functionality into core with associated changes.

parent 0b44909c
@@ -17,7 +17,8 @@ from megatron import print_rank_0
 from megatron.core import mpu
 from megatron.data.blendable_dataset import BlendableDataset
 from megatron.data.gpt_dataset import build_train_valid_test_datasets
-from megatron.model import GPTModel, ModelType
+from megatron.model import GPTModel
+from megatron.core.enums import ModelType
 from megatron.training import pretrain
 from megatron.utils import get_ltor_masks_and_position_ids
 from megatron.utils import average_losses_across_data_parallel_group
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

import enum


class ModelType(enum.Enum):
    encoder_or_decoder = 1
    encoder_and_decoder = 2
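For orientation, a minimal (hypothetical) usage of the new enum from calling code; the values themselves are shown in the file above:

```python
from megatron.core.enums import ModelType

# GPT/BERT-style networks use a single stack (encoder_or_decoder);
# T5-style networks carry both an encoder and a decoder.
model_type = ModelType.encoder_or_decoder
assert model_type is not ModelType.encoder_and_decoder
```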
@@ -58,12 +58,40 @@ def initialize_model_parallel(
     Initialize model data parallel groups.

     Arguments:
-        tensor_model_parallel_size: number of GPUs used for tensor model parallelism.
-        pipeline_model_parallel_size: number of GPUs used for pipeline model parallelism.
-        virtual_pipeline_model_parallel_size: number of virtual stages (interleaved
-                                              pipeline).
-        pipeline_model_parallel_split_rank: for models with both encoder and decoder,
-                                            rank in pipeline with split point.
+        tensor_model_parallel_size (int, default = 1):
+            The number of GPUs to split individual tensors across.
+
+        pipeline_model_parallel_size (int, default = 1):
+            The number of tensor parallel GPU groups to split the
+            Transformer layers across. For example, if
+            tensor_model_parallel_size is 4 and
+            pipeline_model_parallel_size is 2, the model will be split
+            into 2 groups of 4 GPUs.
+
+        virtual_pipeline_model_parallel_size (int, optional):
+            The number of stages that each pipeline group will have,
+            interleaving as necessary. If None, no interleaving is
+            performed. For example, if tensor_model_parallel_size is 1,
+            pipeline_model_parallel_size is 4,
+            virtual_pipeline_model_parallel_size is 2, and there are
+            16 transformer layers in the model, the model will be
+            split into 8 stages with two layers each and each GPU
+            would get 2 stages as such (layer number starting with 1):
+                GPU 0: [1, 2] [9, 10]
+                GPU 1: [3, 4] [11, 12]
+                GPU 2: [5, 6] [13, 14]
+                GPU 3: [7, 8] [15, 16]
+
+        pipeline_model_parallel_split_rank (int, optional):
+            For models with both an encoder and decoder, the rank in
+            pipeline to switch between encoder and decoder (i.e. the
+            first rank of the decoder). This allows the user to set
+            the pipeline parallel size of the encoder and decoder
+            independently. For example, if
+            pipeline_model_parallel_size is 8 and
+            pipeline_model_parallel_split_rank is 3, then ranks 0-2
+            will be the encoder and ranks 3-7 will be the decoder.

     Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
     use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
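To make the interleaving example in the docstring concrete, here is a hedged sketch of how a driver script might request that layout. It assumes this function lives in `megatron.core.parallel_state` (as the moved file suggests) and that `torch.distributed.init_process_group(...)` has already been called on 4 ranks:

```python
from megatron.core import parallel_state

# Hypothetical driver: 4 pipeline stages, 2 virtual stages per GPU.
# With a 16-layer model this yields the layer layout shown in the docstring:
# each GPU owns two chunks of two layers.
parallel_state.initialize_model_parallel(
    tensor_model_parallel_size=1,
    pipeline_model_parallel_size=4,
    virtual_pipeline_model_parallel_size=2,
)
print("pipeline rank:", parallel_state.get_pipeline_model_parallel_rank())
```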
@@ -298,8 +326,8 @@ def set_pipeline_model_parallel_split_rank(rank):

 def set_pipeline_model_parallel_split_rank(rank):
     """Set pipeline model parallel split rank."""
-    global _MPU_PIPELINE_MODEL_PARALLEL_SPLIT_RANK
-    _MPU_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = rank
+    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
+    _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = rank


 def get_tensor_model_parallel_rank():

@@ -318,6 +346,11 @@ def get_pipeline_model_parallel_rank():
     return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())


+def get_pipeline_model_parallel_split_rank():
+    """Return pipeline model parallel split rank."""
+    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
+    return _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
+
+
 def is_pipeline_first_stage(ignore_virtual=False):
     """Return True if in the first pipeline model-parallel stage, False otherwise."""
from .schedules import get_forward_backward_func
@@ -2,15 +2,24 @@

 from functools import reduce
 import operator
+from typing import Optional, List, Union, Callable, Tuple

 import torch

-from megatron import get_args, core
-from megatron.core import mpu
+from megatron import core
+from megatron.core.parallel_state import (
+    get_pipeline_model_parallel_group,
+    get_pipeline_model_parallel_prev_rank,
+    get_pipeline_model_parallel_next_rank,
+)
+
+# Types
+Shape = Union[List[int], torch.Size]


 def _communicate_shapes(tensor_send_next, tensor_send_prev,
-                        recv_prev, recv_next):
+                        recv_prev, recv_next,
+                        use_ring_exchange_p2p):
     """Communicate tensor shapes between stages. Used to communicate
     tensor shapes before the actual tensor communication happens.
     This is required when the sequence lengths across micro batches
     are not uniform.

@@ -28,7 +37,6 @@ def _communicate_shapes(tensor_send_next, tensor_send_prev,
           (recv_prev_shape, recv_next_shape)
     """
-    args = get_args()

     recv_prev_shape_tensor = None
     recv_next_shape_tensor = None
     send_prev_shape_tensor = None

@@ -50,7 +58,7 @@ def _communicate_shapes(tensor_send_next, tensor_send_prev,
                                              device=torch.cuda.current_device(),
                                              dtype=torch.int64)

-    if args.use_ring_exchange_p2p:
+    if use_ring_exchange_p2p:
         torch.distributed.ring_exchange(tensor_send_prev=send_prev_shape_tensor,
                                         tensor_recv_prev=recv_prev_shape_tensor,
                                         tensor_send_next=send_next_shape_tensor,

@@ -98,46 +106,70 @@ def _communicate_shapes(tensor_send_next, tensor_send_prev,
     return recv_prev_shape, recv_next_shape
-def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
-                 tensor_shape,
-                 dtype_=None):
+def _communicate(*, tensor_send_next: Optional[torch.Tensor],
+                 tensor_send_prev: Optional[torch.Tensor],
+                 recv_prev: bool,
+                 recv_next: bool,
+                 tensor_shape: Shape,
+                 dtype: Optional[torch.dtype],
+                 variable_seq_lengths: bool = False,
+                 use_ring_exchange_p2p: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
     """Communicate tensors between stages. Used as helper method in other
     communication methods that are used in megatron/schedules.py.

-    Takes the following arguments:
-        tensor_send_next: tensor to send to next rank (no tensor sent if
-                          set to None).
-        tensor_send_prev: tensor to send to prev rank (no tensor sent if
-                          set to None).
-        recv_prev: boolean for whether tensor should be received from
-                   previous rank.
-        recv_next: boolean for whether tensor should be received from
-                   next rank.
-        tensor_shape: shape of tensor to receive (this method assumes that all
-                      tensors sent and received in a single function call are
-                      the same shape).
-        dtype_: optional, this is used when the tensor that needs to be
-                communicated is different from args.params_dtype.
+    Arguments:
+        tensor_send_next (torch.Tensor, optional):
+            Tensor to send to next rank (no tensor sent if None)
+
+        tensor_send_prev (torch.Tensor, optional):
+            Tensor to send to prev rank (no tensor sent if None)
+
+        recv_prev (boolean, required):
+            whether tensor should be received from previous rank.
+
+        recv_next (boolean, required):
+            whether tensor should be received from next rank.
+
+        tensor_shape (List[int] or torch.Size, required):
+            shape of tensor to receive (this method assumes that all
+            tensors sent and received in a single function call are
+            the same shape).
+
+        dtype (torch.dtype, required if either recv_{prev,next} is True):
+            this must be the type of the tensors that will be
+            received, will typically be params_dtype, but in the case
+            of fp32 residual connections might be torch.float.
+
+        variable_seq_lengths (bool, optional, default=False):
+            Support for variable sequence lengths across
+            microbatches. Setting this communicates the size of
+            tensors during pipeline parallelism communication, because
+            of this extra overhead it should only be set if the
+            sequence length is not constant during training.
+
+        use_ring_exchange_p2p (bool, optional, default = False):
+            Use custom ring_exchange kernel instead of
+            torch.distributed.batch_isend_irecv(). Requires custom
+            built torch with torch.distributed.ring_exchange.

     Returns:
-        (tensor_recv_prev, tensor_recv_next)
+        tuple containing
+
+        - tensor_recv_prev: torch.Tensor if recv_prev is True, None otherwise.
+        - tensor_recv_next: torch.Tensor if recv_next is True, None otherwise.

     """
-    args = get_args()

     # Create placeholder tensors for receive in forward and backward directions
     # if needed.
     tensor_recv_prev = None
     tensor_recv_next = None

-    # Some legacy inference code doesn't set the tensor shape, do so now
-    # for the normal values for gpt/bert. This could be removed if inference
-    # code is changed to provide tensor_shape.
-    if not args.variable_seq_lengths:
-        if tensor_shape is None:
-            recv_prev_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
-            recv_next_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
-        else:
-            recv_prev_shape = tensor_shape
-            recv_next_shape = tensor_shape
+    if not variable_seq_lengths:
+        recv_prev_shape = tensor_shape
+        recv_next_shape = tensor_shape
     else:
         recv_prev_shape, recv_next_shape = \
             _communicate_shapes(tensor_send_next,

@@ -145,116 +177,81 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
                                 recv_prev,
                                 recv_next)
override_scatter_gather_tensors_in_pipeline = False
if args.scatter_gather_tensors_in_pipeline and \
not args.sequence_parallel:
recv_prev_chunk_shape = reduce(operator.mul, recv_prev_shape, 1)
recv_next_chunk_shape = reduce(operator.mul, recv_next_shape, 1)
if recv_prev_chunk_shape % mpu.get_tensor_model_parallel_world_size() == 0 and \
recv_next_chunk_shape % mpu.get_tensor_model_parallel_world_size() == 0:
recv_prev_chunk_shape = recv_prev_chunk_shape // \
mpu.get_tensor_model_parallel_world_size()
recv_next_chunk_shape = recv_next_chunk_shape // \
mpu.get_tensor_model_parallel_world_size()
else:
recv_prev_chunk_shape = recv_prev_shape
recv_next_chunk_shape = recv_next_shape
override_scatter_gather_tensors_in_pipeline = True
else:
recv_prev_chunk_shape = recv_prev_shape
recv_next_chunk_shape = recv_next_shape
dtype = args.params_dtype
if args.fp32_residual_connection:
dtype = torch.float
requires_grad = True
if dtype_ is not None:
dtype = dtype_
requires_grad = False
if recv_prev: if recv_prev:
tensor_recv_prev = torch.empty(recv_prev_chunk_shape, if dtype is None:
requires_grad=requires_grad, raise RuntimeError("dtype must be provided if recv_prev is True")
if tensor_shape is None:
raise RuntimeError(
"tensor_shape must be specified if recv_prev is True. "
"Common tensor_shape is (seq_length, micro_batch_size, hidden_size)"
)
tensor_recv_prev = torch.empty(recv_prev_shape,
requires_grad=True,
device=torch.cuda.current_device(), device=torch.cuda.current_device(),
dtype=dtype) dtype=dtype)
if recv_next: if recv_next:
tensor_recv_next = torch.empty(recv_next_chunk_shape, if dtype is None:
requires_grad=requires_grad, raise RuntimeError("dtype must be provided if recv_next is True")
if tensor_shape is None:
raise RuntimeError(
"tensor_shape must be specified if recv_next is True. "
"Common tensor_shape is (seq_length, micro_batch_size, hidden_size)"
)
tensor_recv_next = torch.empty(recv_next_shape,
requires_grad=True,
device=torch.cuda.current_device(), device=torch.cuda.current_device(),
dtype=dtype) dtype=dtype)
# Split tensor into smaller chunks if using scatter-gather optimization.
if not override_scatter_gather_tensors_in_pipeline and \
args.scatter_gather_tensors_in_pipeline and \
not args.sequence_parallel:
if tensor_send_next is not None:
tensor_send_next = core.tensor_parallel.split_tensor_into_1d_equal_chunks(tensor_send_next)
if tensor_send_prev is not None:
tensor_send_prev = core.tensor_parallel.split_tensor_into_1d_equal_chunks(tensor_send_prev)
# Send tensors in both the forward and backward directions as appropriate. # Send tensors in both the forward and backward directions as appropriate.
if args.use_ring_exchange_p2p: if use_ring_exchange_p2p:
torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev, torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
tensor_recv_prev=tensor_recv_prev, tensor_recv_prev=tensor_recv_prev,
tensor_send_next=tensor_send_next, tensor_send_next=tensor_send_next,
tensor_recv_next=tensor_recv_next, tensor_recv_next=tensor_recv_next,
group=mpu.get_pipeline_model_parallel_group()) group=get_pipeline_model_parallel_group())
else: else:
ops = [] ops = []
if tensor_send_prev is not None: if tensor_send_prev is not None:
send_prev_op = torch.distributed.P2POp( send_prev_op = torch.distributed.P2POp(
torch.distributed.isend, tensor_send_prev, torch.distributed.isend, tensor_send_prev,
mpu.get_pipeline_model_parallel_prev_rank()) get_pipeline_model_parallel_prev_rank())
ops.append(send_prev_op) ops.append(send_prev_op)
if tensor_recv_prev is not None: if tensor_recv_prev is not None:
recv_prev_op = torch.distributed.P2POp( recv_prev_op = torch.distributed.P2POp(
torch.distributed.irecv, tensor_recv_prev, torch.distributed.irecv, tensor_recv_prev,
mpu.get_pipeline_model_parallel_prev_rank()) get_pipeline_model_parallel_prev_rank())
ops.append(recv_prev_op) ops.append(recv_prev_op)
if tensor_send_next is not None: if tensor_send_next is not None:
send_next_op = torch.distributed.P2POp( send_next_op = torch.distributed.P2POp(
torch.distributed.isend, tensor_send_next, torch.distributed.isend, tensor_send_next,
mpu.get_pipeline_model_parallel_next_rank()) get_pipeline_model_parallel_next_rank())
ops.append(send_next_op) ops.append(send_next_op)
if tensor_recv_next is not None: if tensor_recv_next is not None:
recv_next_op = torch.distributed.P2POp( recv_next_op = torch.distributed.P2POp(
torch.distributed.irecv, tensor_recv_next, torch.distributed.irecv, tensor_recv_next,
mpu.get_pipeline_model_parallel_next_rank()) get_pipeline_model_parallel_next_rank())
ops.append(recv_next_op) ops.append(recv_next_op)
if len(ops) > 0: if len(ops) > 0:
reqs = torch.distributed.batch_isend_irecv(ops) reqs = torch.distributed.batch_isend_irecv(ops)
for req in reqs: for req in reqs:
req.wait() req.wait()
# To protect against race condition when using batch_isend_irecv(). # To protect against race condition when using batch_isend_irecv().
# User should assert that we have a modern enough PyTorch to not need this
torch.cuda.synchronize() torch.cuda.synchronize()
# If using scatter-gather optimization, gather smaller chunks.
if not override_scatter_gather_tensors_in_pipeline and \
args.scatter_gather_tensors_in_pipeline and \
not args.sequence_parallel:
if recv_prev:
tensor_recv_prev = core.tensor_parallel.gather_split_1d_tensor(
tensor_recv_prev).view(recv_prev_shape).requires_grad_()
tensor_recv_prev = core.utils.make_viewless_tensor(tensor_recv_prev,
requires_grad=True,
keep_graph=False)
if recv_next:
tensor_recv_next = core.tensor_parallel.gather_split_1d_tensor(
tensor_recv_next).view(recv_next_shape).requires_grad_()
tensor_recv_next = core.utils.make_viewless_tensor(tensor_recv_next,
requires_grad=True,
keep_graph=False)
    return tensor_recv_prev, tensor_recv_next
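For readers unfamiliar with the batched point-to-point pattern `_communicate` relies on, here is a small standalone sketch (not Megatron code; the function name, neighbor choice, and shapes are made up) of pairing `torch.distributed.P2POp` operations with `batch_isend_irecv`:

```python
import torch
import torch.distributed as dist

# Assumes dist.init_process_group("nccl") has already run and each rank owns a GPU.
def exchange_with_neighbors(tensor_to_next, shape, dtype):
    rank = dist.get_rank()
    world = dist.get_world_size()
    recv_from_prev = torch.empty(shape, dtype=dtype,
                                 device=torch.cuda.current_device())

    ops = []
    if rank + 1 < world:   # send activations "forward"
        ops.append(dist.P2POp(dist.isend, tensor_to_next, rank + 1))
    if rank - 1 >= 0:      # receive activations from the previous stage
        ops.append(dist.P2POp(dist.irecv, recv_from_prev, rank - 1))

    if ops:
        for req in dist.batch_isend_irecv(ops):
            req.wait()
    torch.cuda.synchronize()  # mirrors the race-condition guard in _communicate
    return recv_from_prev if rank - 1 >= 0 else None
```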
def recv_forward(tensor_shape=None, dtype_=None, timers=None): def recv_forward(tensor_shape: Shape,
"""Receive tensor from previous rank in pipeline (forward receive).""" dtype: torch.dtype,
timers: Callable = None) -> torch.Tensor:
""" Receive tensor from previous rank in pipeline (forward receive).
if mpu.is_pipeline_first_stage():
See _communicate for argument details.
"""
if core.parallel_state.is_pipeline_first_stage():
input_tensor = None input_tensor = None
else: else:
if timers is not None: if timers is not None:
...@@ -265,15 +262,20 @@ def recv_forward(tensor_shape=None, dtype_=None, timers=None): ...@@ -265,15 +262,20 @@ def recv_forward(tensor_shape=None, dtype_=None, timers=None):
recv_prev=True, recv_prev=True,
recv_next=False, recv_next=False,
tensor_shape=tensor_shape, tensor_shape=tensor_shape,
dtype_=dtype_) dtype=dtype)
if timers is not None: if timers is not None:
timers('forward-recv').stop() timers('forward-recv').stop()
return input_tensor return input_tensor
def recv_backward(tensor_shape=None, timers=None): def recv_backward(tensor_shape: Shape,
"""Receive tensor from next rank in pipeline (backward receive).""" dtype: torch.dtype,
if mpu.is_pipeline_last_stage(): timers: Callable = None) -> torch.Tensor:
"""Receive tensor from next rank in pipeline (backward receive).
See _communicate for argument details.
"""
if core.parallel_state.is_pipeline_last_stage():
output_tensor_grad = None output_tensor_grad = None
else: else:
if timers is not None: if timers is not None:
...@@ -283,16 +285,21 @@ def recv_backward(tensor_shape=None, timers=None): ...@@ -283,16 +285,21 @@ def recv_backward(tensor_shape=None, timers=None):
tensor_send_prev=None, tensor_send_prev=None,
recv_prev=False, recv_prev=False,
recv_next=True, recv_next=True,
tensor_shape=tensor_shape) tensor_shape=tensor_shape,
dtype=dtype)
if timers is not None: if timers is not None:
timers('backward-recv').stop() timers('backward-recv').stop()
return output_tensor_grad return output_tensor_grad
def send_forward(output_tensor, tensor_shape=None, dtype_=None, timers=None): def send_forward(output_tensor: torch.Tensor,
"""Send tensor to next rank in pipeline (forward send).""" timers: Callable = None) -> None:
"""Send tensor to next rank in pipeline (forward send).
See _communicate for argument details.
"""
if not mpu.is_pipeline_last_stage(): if not core.parallel_state.is_pipeline_last_stage():
if timers is not None: if timers is not None:
timers('forward-send', log_level=2).start() timers('forward-send', log_level=2).start()
_communicate( _communicate(
...@@ -300,15 +307,19 @@ def send_forward(output_tensor, tensor_shape=None, dtype_=None, timers=None): ...@@ -300,15 +307,19 @@ def send_forward(output_tensor, tensor_shape=None, dtype_=None, timers=None):
tensor_send_prev=None, tensor_send_prev=None,
recv_prev=False, recv_prev=False,
recv_next=False, recv_next=False,
tensor_shape=tensor_shape, tensor_shape=None,
dtype_=dtype_) dtype=None)
if timers is not None: if timers is not None:
timers('forward-send').stop() timers('forward-send').stop()
def send_backward(input_tensor_grad, tensor_shape=None, timers=None): def send_backward(input_tensor_grad: torch.Tensor,
"""Send tensor to previous rank in pipeline (backward send).""" timers: Callable = None) -> None:
if not mpu.is_pipeline_first_stage(): """Send tensor to previous rank in pipeline (backward send).
See _communicate for argument details.
"""
if not core.parallel_state.is_pipeline_first_stage():
if timers is not None: if timers is not None:
timers('backward-send', log_level=2).start() timers('backward-send', log_level=2).start()
_communicate( _communicate(
...@@ -316,14 +327,21 @@ def send_backward(input_tensor_grad, tensor_shape=None, timers=None): ...@@ -316,14 +327,21 @@ def send_backward(input_tensor_grad, tensor_shape=None, timers=None):
tensor_send_prev=input_tensor_grad, tensor_send_prev=input_tensor_grad,
recv_prev=False, recv_prev=False,
recv_next=False, recv_next=False,
tensor_shape=tensor_shape) tensor_shape=None,
dtype=None)
if timers is not None: if timers is not None:
timers('backward-send').stop() timers('backward-send').stop()
def send_forward_recv_backward(output_tensor, tensor_shape=None, timers=None): def send_forward_recv_backward(output_tensor: torch.Tensor,
"""Batched send and recv with next rank in pipeline.""" tensor_shape: Shape,
if mpu.is_pipeline_last_stage(): dtype: torch.dtype,
timers: Callable = None) -> torch.Tensor:
"""Batched send and recv with next rank in pipeline.
See _communicate for argument details.
"""
if core.parallel_state.is_pipeline_last_stage():
output_tensor_grad = None output_tensor_grad = None
else: else:
if timers is not None: if timers is not None:
...@@ -333,15 +351,22 @@ def send_forward_recv_backward(output_tensor, tensor_shape=None, timers=None): ...@@ -333,15 +351,22 @@ def send_forward_recv_backward(output_tensor, tensor_shape=None, timers=None):
tensor_send_prev=None, tensor_send_prev=None,
recv_prev=False, recv_prev=False,
recv_next=True, recv_next=True,
tensor_shape=tensor_shape) tensor_shape=tensor_shape,
dtype=dtype)
if timers is not None: if timers is not None:
timers('forward-send-backward-recv').stop() timers('forward-send-backward-recv').stop()
return output_tensor_grad return output_tensor_grad
def send_backward_recv_forward(input_tensor_grad, tensor_shape=None, timers=None): def send_backward_recv_forward(input_tensor_grad: torch.Tensor,
"""Batched send and recv with previous rank in pipeline.""" tensor_shape: Shape,
if mpu.is_pipeline_first_stage(): dtype: torch.dtype,
timers: Callable = None) -> torch.Tensor:
"""Batched send and recv with previous rank in pipeline.
See _communicate for argument details.
"""
if core.parallel_state.is_pipeline_first_stage():
input_tensor = None input_tensor = None
else: else:
if timers is not None: if timers is not None:
...@@ -351,14 +376,22 @@ def send_backward_recv_forward(input_tensor_grad, tensor_shape=None, timers=None ...@@ -351,14 +376,22 @@ def send_backward_recv_forward(input_tensor_grad, tensor_shape=None, timers=None
tensor_send_prev=input_tensor_grad, tensor_send_prev=input_tensor_grad,
recv_prev=True, recv_prev=True,
recv_next=False, recv_next=False,
tensor_shape=tensor_shape) tensor_shape=tensor_shape,
dtype=dtype)
if timers is not None: if timers is not None:
timers('backward-send-forward-recv').stop() timers('backward-send-forward-recv').stop()
return input_tensor return input_tensor
def send_forward_recv_forward(output_tensor, recv_prev, tensor_shape=None, timers=None): def send_forward_recv_forward(output_tensor: torch.Tensor,
"""Batched recv from previous rank and send to next rank in pipeline.""" recv_prev: bool,
tensor_shape: Shape,
dtype: torch.dtype,
timers: Callable = None) -> torch.Tensor:
"""Batched recv from previous rank and send to next rank in pipeline.
See _communicate for argument details.
"""
if timers is not None: if timers is not None:
timers('forward-send-forward-recv', log_level=2).start() timers('forward-send-forward-recv', log_level=2).start()
input_tensor, _ = _communicate( input_tensor, _ = _communicate(
...@@ -366,14 +399,22 @@ def send_forward_recv_forward(output_tensor, recv_prev, tensor_shape=None, timer ...@@ -366,14 +399,22 @@ def send_forward_recv_forward(output_tensor, recv_prev, tensor_shape=None, timer
tensor_send_prev=None, tensor_send_prev=None,
recv_prev=recv_prev, recv_prev=recv_prev,
recv_next=False, recv_next=False,
tensor_shape=tensor_shape) tensor_shape=tensor_shape,
dtype=dtype)
if timers is not None: if timers is not None:
timers('forward-send-forward-recv').stop() timers('forward-send-forward-recv').stop()
return input_tensor return input_tensor
def send_backward_recv_backward(input_tensor_grad, recv_next, tensor_shape=None, timers=None): def send_backward_recv_backward(input_tensor_grad: torch.Tensor,
"""Batched recv from next rank and send to previous rank in pipeline.""" recv_next: bool,
tensor_shape: Shape,
dtype: torch.dtype,
timers: Callable = None) -> torch.Tensor:
"""Batched recv from next rank and send to previous rank in pipeline.
See _communicate for argument details.
"""
if timers is not None: if timers is not None:
timers('backward-send-backward-recv', log_level=2).start() timers('backward-send-backward-recv', log_level=2).start()
_, output_tensor_grad = _communicate( _, output_tensor_grad = _communicate(
...@@ -381,16 +422,25 @@ def send_backward_recv_backward(input_tensor_grad, recv_next, tensor_shape=None, ...@@ -381,16 +422,25 @@ def send_backward_recv_backward(input_tensor_grad, recv_next, tensor_shape=None,
tensor_send_prev=input_tensor_grad, tensor_send_prev=input_tensor_grad,
recv_prev=False, recv_prev=False,
recv_next=recv_next, recv_next=recv_next,
tensor_shape=tensor_shape) tensor_shape=tensor_shape,
dtype=dtype)
if timers is not None: if timers is not None:
timers('backward-send-backward-recv').stop() timers('backward-send-backward-recv').stop()
return output_tensor_grad return output_tensor_grad
def send_forward_backward_recv_forward_backward( def send_forward_backward_recv_forward_backward(
output_tensor, input_tensor_grad, recv_prev, output_tensor: torch.Tensor,
recv_next, tensor_shape=None, timers=None): input_tensor_grad: torch.Tensor,
"""Batched send and recv with previous and next ranks in pipeline.""" recv_prev: bool,
recv_next: bool,
tensor_shape: Shape,
dtype: torch.dtype,
timers: Callable = None) -> Tuple[torch.Tensor, torch.Tensor]:
"""Batched send and recv with previous and next ranks in pipeline.
See _communicate for argument details.
"""
if timers is not None: if timers is not None:
timers('forward-backward-send-forward-backward-recv', timers('forward-backward-send-forward-backward-recv',
log_level=2).start() log_level=2).start()
...@@ -399,7 +449,8 @@ def send_forward_backward_recv_forward_backward( ...@@ -399,7 +449,8 @@ def send_forward_backward_recv_forward_backward(
tensor_send_prev=input_tensor_grad, tensor_send_prev=input_tensor_grad,
recv_prev=recv_prev, recv_prev=recv_prev,
recv_next=recv_next, recv_next=recv_next,
tensor_shape=tensor_shape) tensor_shape=tensor_shape,
dtype=dtype)
if timers is not None: if timers is not None:
timers('forward-backward-send-forward-backward-recv').stop() timers('forward-backward-send-forward-backward-recv').stop()
    return input_tensor, output_tensor_grad
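As a rough illustration of how these wrappers pair up across adjacent stages, consider the sketch below. It is simplified: in the real schedules the received activation is handed to the model via `set_input_tensor` rather than passed positionally, and `get_batch`, `model_chunk`, and the shapes here are placeholders:

```python
from megatron.core import parallel_state
from megatron.core.pipeline_parallel import p2p_communication

def forward_microbatch(model_chunk, get_batch, tensor_shape, dtype):
    """Illustrative forward pass of one microbatch on this pipeline stage."""
    if parallel_state.is_pipeline_first_stage():
        hidden = model_chunk(get_batch())          # first stage consumes real data
    else:
        recv = p2p_communication.recv_forward(tensor_shape, dtype)
        hidden = model_chunk(recv)                 # later stages consume activations
    # send_forward() is a no-op on the last stage, so it is safe to call everywhere.
    p2p_communication.send_forward(hidden)
    return hidden
```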
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
+from typing import Optional, List, Union, Callable, Any

 import torch
 from torch.autograd.variable import Variable
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

-from megatron import get_args
-from megatron import get_num_microbatches
-from megatron import get_timers
-from megatron import p2p_communication
-from megatron.core import mpu
-from megatron.utils import unwrap_model
-from megatron.model import DistributedDataParallel as LocalDDP
-from megatron.model import Float16Module
-from megatron.model import ModelType
+from megatron.core import parallel_state
+from megatron.core.pipeline_parallel import p2p_communication
+from megatron.core.enums import ModelType
+from megatron.core.utils import get_attr_wrapped_model, get_model_type
+
+# Types
+Shape = Union[List[int], torch.Size]


 def get_forward_backward_func():
-    args = get_args()
-    if mpu.get_pipeline_model_parallel_world_size() > 1:
-        if args.virtual_pipeline_model_parallel_size is not None:
+    """Retrieves the appropriate forward_backward function given the
+    configuration of parallel_state.
+
+    Returns a function that will perform all of the forward and
+    backward passes of the model given the pipeline model parallel
+    world size and virtual pipeline model parallel world size in the
+    global parallel_state.
+
+    The function returned takes the following arguments:
+
+    forward_step_func (required): A function that takes a data
+        iterator and a model as its arguments and return the model's
+        forward output and the loss function. The loss function should
+        take one torch.Tensor and return a torch.Tensor of loss and a
+        dictionary of string -> torch.Tensor.
+
+        For example:
+
+            def loss_func(loss_mask, output_tensor):
+                losses = output_tensor.float()
+                loss_mask = loss_mask.view(-1).float()
+                loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
+
+                # Reduce loss for logging.
+                averaged_loss = average_losses_across_data_parallel_group([loss])
+
+                return loss, {'lm loss': averaged_loss[0]}
+
+            def forward_step(data_iterator, model):
+                data, loss_mask = next(data_iterator)
+                output = model(data)
+                return output, partial(loss_func, loss_mask)
+
+            forward_backward_func(forward_step_func=forward_step, ...)
+
+    data_iterator (required): an iterator over the data, will be
+        passed as is to forward_step_func
+
+    model (required): the actual model. A torch.nn.Module or, in the
+        case of interleaving, a list of torch.nn.Module
+
+    num_microbatches (int, required):
+        The number of microbatches to go through
+
+    dtype (required when using pipeline parallelism): dtype used in
+        p2p communication, usually params_dtype
+
+    tensor_shape (required when using pipeline parallelism): Shape of
+        tensor. The tensor is expected to be 3D and its order of
+        dimension is supposed to be ``(sequence, batch, hidden)``.
+
+    decoder_seq_length (int, required for ModelType.encoder_and_decoder models):
+        Sequence length of the decoder portion, used to determine tensor shapes.
+
+    grad_scaler (optional, default=None): If using loss scaling,
+        this function should take the loss and return the scaled
+        loss. If None, no function is called on the loss.
+
+    sequence_parallel (optional, default=False):
+        Set to :obj:`True` for this function to handle sequence
+        length. When :obj:`True`, the sequence length on each tensor
+        model parallel rank is updated to
+        :math:`original\_sequence\_length /
+        tensor\_model\_parallel\_world\_size`.
+        TODO: Do we need this? Just roll into tensor_shape arg?
+
+    forward_only (optional, default=False): Perform only the forward step
+
+    timers (optional, default=None): TODO
+
+    collect_non_loss_data: TODO
+
+    """
+    pipeline_model_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
+    if pipeline_model_parallel_size > 1:
+        if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
             forward_backward_func = forward_backward_pipelining_with_interleaving
-            assert get_num_microbatches() % \
-                args.pipeline_model_parallel_size == 0, \
-                'number of microbatches (%d) is not divisible by pipeline-' \
-                'model-parallel-size (%d) when using interleaved schedule' % (
-                    get_num_microbatches(),
-                    args.pipeline_model_parallel_size,
-                )
         else:
             forward_backward_func = forward_backward_pipelining_without_interleaving
     else:
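A hedged end-to-end sketch of driving this API from a training loop. `train_data_iterator` and `model` are placeholders that the caller must supply, the loss reduction is simplified (the docstring example uses `average_losses_across_data_parallel_group`), and the shape and dtype values are illustrative:

```python
from functools import partial
import torch
from megatron.core.pipeline_parallel import get_forward_backward_func

def loss_func(loss_mask, output_tensor):
    losses = output_tensor.float()
    loss_mask = loss_mask.view(-1).float()
    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
    return loss, {'lm loss': loss.detach()}   # simplified logging reduction

def forward_step(data_iterator, model):
    data, loss_mask = next(data_iterator)
    output = model(data)
    return output, partial(loss_func, loss_mask)

forward_backward_func = get_forward_backward_func()
losses_reduced = forward_backward_func(
    forward_step_func=forward_step,
    data_iterator=train_data_iterator,    # placeholder iterator
    model=[model],                        # a list, even without interleaving
    num_microbatches=8,
    dtype=torch.bfloat16,
    tensor_shape=(2048, 4, 4096),         # (sequence, micro-batch, hidden)
    forward_only=False,
)
```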
...@@ -52,7 +119,7 @@ def deallocate_output_tensor(out): ...@@ -52,7 +119,7 @@ def deallocate_output_tensor(out):
device = out.device, device = out.device,
dtype = out.dtype, dtype = out.dtype,
) )
def custom_backward(output, grad_output): def custom_backward(output, grad_output):
'''Directly call C++ autograd engine. '''Directly call C++ autograd engine.
...@@ -87,11 +154,15 @@ def custom_backward(output, grad_output): ...@@ -87,11 +154,15 @@ def custom_backward(output, grad_output):
allow_unreachable=True, allow_unreachable=True,
accumulate_grad=True, accumulate_grad=True,
) )
def forward_step(forward_step_func, def forward_step(forward_step_func,
data_iterator, data_iterator,
model, model,
num_microbatches,
input_tensor, input_tensor,
forward_data_store, forward_data_store,
timers, timers,
...@@ -102,25 +173,26 @@ def forward_step(forward_step_func, ...@@ -102,25 +173,26 @@ def forward_step(forward_step_func,
passed-in input_tensor is used. passed-in input_tensor is used.
Returns output tensor.""" Returns output tensor."""
args = get_args()
if timers is not None: if timers is not None:
timers('forward-compute', log_level=2).start() timers('forward-compute', log_level=2).start()
unwrapped_model = unwrap_model(
model, (torchDDP, LocalDDP, Float16Module))
unwrap_output_tensor = False unwrap_output_tensor = False
if not isinstance(input_tensor, list): if not isinstance(input_tensor, list):
input_tensor = [input_tensor] input_tensor = [input_tensor]
unwrap_output_tensor = True unwrap_output_tensor = True
unwrapped_model.set_input_tensor(input_tensor) set_input_tensor = get_attr_wrapped_model(model, "set_input_tensor")
output_tensor, loss_func = forward_step_func(data_iterator, model) set_input_tensor(input_tensor)
if mpu.is_pipeline_last_stage():
context_manager = torch.autocast("cuda") if torch.is_autocast_enabled() else nullcontext()
with context_manager:
output_tensor, loss_func = forward_step_func(data_iterator, model)
if parallel_state.is_pipeline_last_stage():
if not collect_non_loss_data: if not collect_non_loss_data:
output_tensor = loss_func(output_tensor) output_tensor = loss_func(output_tensor)
loss, loss_reduced = output_tensor loss, loss_reduced = output_tensor
output_tensor = loss / get_num_microbatches() output_tensor = loss / num_microbatches
forward_data_store.append(loss_reduced) forward_data_store.append(loss_reduced)
else: else:
data = loss_func(output_tensor, non_loss_data=True) data = loss_func(output_tensor, non_loss_data=True)
...@@ -132,16 +204,18 @@ def forward_step(forward_step_func, ...@@ -132,16 +204,18 @@ def forward_step(forward_step_func,
# If T5 model (or other model with encoder and decoder) # If T5 model (or other model with encoder and decoder)
# and in decoder stack, then send encoder_hidden_state # and in decoder stack, then send encoder_hidden_state
# downstream as well. # downstream as well.
if mpu.is_pipeline_stage_after_split() and \ model_type = get_model_type(model)
args.model_type == ModelType.encoder_and_decoder:
if parallel_state.is_pipeline_stage_after_split() and \
model_type == ModelType.encoder_and_decoder:
return [output_tensor, input_tensor[-1]] return [output_tensor, input_tensor[-1]]
if unwrap_output_tensor: if unwrap_output_tensor:
return output_tensor return output_tensor
return [output_tensor] return [output_tensor]
def backward_step(optimizer, input_tensor, output_tensor, def backward_step(grad_scaler, input_tensor, output_tensor,
output_tensor_grad, timers): output_tensor_grad, model_type, timers):
"""Backward step through passed-in output tensor. """Backward step through passed-in output tensor.
If last stage, output_tensor_grad is None, otherwise gradient of loss If last stage, output_tensor_grad is None, otherwise gradient of loss
...@@ -153,7 +227,6 @@ def backward_step(optimizer, input_tensor, output_tensor, ...@@ -153,7 +227,6 @@ def backward_step(optimizer, input_tensor, output_tensor,
# NOTE: This code currently can handle at most one skip connection. It # NOTE: This code currently can handle at most one skip connection. It
# needs to be modified slightly to support arbitrary numbers of skip # needs to be modified slightly to support arbitrary numbers of skip
# connections. # connections.
args = get_args()
if timers is not None: if timers is not None:
timers('backward-compute', log_level=2).start() timers('backward-compute', log_level=2).start()
...@@ -173,8 +246,8 @@ def backward_step(optimizer, input_tensor, output_tensor, ...@@ -173,8 +246,8 @@ def backward_step(optimizer, input_tensor, output_tensor,
output_tensor_grad = [output_tensor_grad] output_tensor_grad = [output_tensor_grad]
# Backward pass. # Backward pass.
if output_tensor_grad[0] is None: if output_tensor_grad[0] is None and grad_scaler is not None:
output_tensor = optimizer.scale_loss(output_tensor[0]) output_tensor = grad_scaler(output_tensor[0])
custom_backward(output_tensor[0], output_tensor_grad[0]) custom_backward(output_tensor[0], output_tensor_grad[0])
# Collect the grad of the input_tensor. # Collect the grad of the input_tensor.
...@@ -189,9 +262,9 @@ def backward_step(optimizer, input_tensor, output_tensor, ...@@ -189,9 +262,9 @@ def backward_step(optimizer, input_tensor, output_tensor,
# Handle single skip connection if it exists (encoder_hidden_state in # Handle single skip connection if it exists (encoder_hidden_state in
# model with encoder and decoder). # model with encoder and decoder).
if mpu.get_pipeline_model_parallel_world_size() > 1 and \ if parallel_state.get_pipeline_model_parallel_world_size() > 1 and \
mpu.is_pipeline_stage_after_split() and \ parallel_state.is_pipeline_stage_after_split() and \
args.model_type == ModelType.encoder_and_decoder: model_type == ModelType.encoder_and_decoder:
if output_tensor_grad[1] is not None: if output_tensor_grad[1] is not None:
input_tensor_grad[-1].add_(output_tensor_grad[1]) input_tensor_grad[-1].add_(output_tensor_grad[1])
if unwrap_input_tensor_grad: if unwrap_input_tensor_grad:
...@@ -211,16 +284,27 @@ def dummy_handler(): ...@@ -211,16 +284,27 @@ def dummy_handler():
pass pass
def forward_backward_no_pipelining(forward_step_func, def forward_backward_no_pipelining(*,
data_iterator, model, forward_step_func,
optimizer, data_iterator,
timers, model: Union[torch.nn.Module, List[torch.nn.Module]],
forward_only, num_microbatches: int,
collect_non_loss_data=False): dtype: Optional[torch.dtype] = None, # unused
tensor_shape: Optional[Shape] = None, # unused
decoder_seq_length: Optional[int] = None, # unused
grad_scaler: Callable = None,
sequence_parallel: bool = False, # unused
forward_only: bool = False,
timers: Callable = None,
collect_non_loss_data: bool = False):
"""Run forward and backward passes with no pipeline parallelism """Run forward and backward passes with no pipeline parallelism
(no inter-stage communication). (no inter-stage communication).
Returns dictionary with losses.""" Returns dictionary with losses.
See get_forward_backward_func() for argument details
"""
assert len(model) == 1 assert len(model) == 1
model = model[0] model = model[0]
...@@ -228,64 +312,86 @@ def forward_backward_no_pipelining(forward_step_func, ...@@ -228,64 +312,86 @@ def forward_backward_no_pipelining(forward_step_func,
if isinstance(model, torchDDP): if isinstance(model, torchDDP):
context_handler = model.no_sync context_handler = model.no_sync
model_type = get_model_type(model)
forward_data_store = [] forward_data_store = []
input_tensor, output_tensor_grad = None, None input_tensor, output_tensor_grad = None, None
with context_handler(): with context_handler():
for i in range(get_num_microbatches() - 1): for i in range(num_microbatches - 1):
output_tensor = forward_step(forward_step_func, data_iterator, output_tensor = forward_step(forward_step_func, data_iterator,
model, input_tensor, forward_data_store, model, num_microbatches, input_tensor, forward_data_store,
timers, collect_non_loss_data) timers, collect_non_loss_data)
if not forward_only: if not forward_only:
backward_step(optimizer, input_tensor, output_tensor, backward_step(grad_scaler, input_tensor, output_tensor,
output_tensor_grad, timers) output_tensor_grad, model_type, timers)
# Run computation for last microbatch out of context handler (want to # Run computation for last microbatch out of context handler (want to
# synchronize gradients). # synchronize gradients).
output_tensor = forward_step(forward_step_func, data_iterator, output_tensor = forward_step(forward_step_func, data_iterator,
model, input_tensor, forward_data_store, model, num_microbatches, input_tensor, forward_data_store,
timers, collect_non_loss_data) timers, collect_non_loss_data)
if not forward_only: if not forward_only:
backward_step(optimizer, input_tensor, output_tensor, backward_step(grad_scaler, input_tensor, output_tensor,
output_tensor_grad, timers) output_tensor_grad, model_type, timers)
    return forward_data_store
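The context-handler trick above (all but the last microbatch inside `no_sync()`) is the standard DDP gradient-accumulation pattern; a generic, hedged PyTorch illustration outside Megatron:

```python
import torch
from torch.nn.parallel import DistributedDataParallel as DDP

def accumulate_microbatches(ddp_model: DDP, batches, loss_fn, num_microbatches):
    """Delay the gradient all-reduce until the last microbatch, as the schedule above does."""
    with ddp_model.no_sync():   # skip gradient all-reduce for the first N-1 microbatches
        for data, target in batches[:num_microbatches - 1]:
            loss_fn(ddp_model(data), target).div(num_microbatches).backward()
    # Final microbatch outside no_sync() triggers the gradient all-reduce.
    data, target = batches[num_microbatches - 1]
    loss_fn(ddp_model(data), target).div(num_microbatches).backward()
```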
def forward_backward_pipelining_with_interleaving(forward_step_func, def forward_backward_pipelining_with_interleaving(*,
data_iterator, model, forward_step_func,
optimizer, data_iterator,
timers, model: Union[torch.nn.Module, List[torch.nn.Module]],
forward_only, num_microbatches: int,
collect_non_loss_data=False): dtype: torch.dtype,
tensor_shape: Shape,
decoder_seq_length: Optional[int] = None,
grad_scaler: Callable = None,
sequence_parallel: bool = False,
forward_only: bool = False,
timers: Callable = None,
collect_non_loss_data: bool = False):
"""Run interleaved 1F1B schedule (model split into model chunks), with """Run interleaved 1F1B schedule (model split into model chunks), with
communication between pipeline stages as needed. communication between pipeline stages as needed.
Returns dictionary with losses if the last stage, empty dict otherwise.""" Returns dictionary with losses if the last stage, empty dict otherwise."""
args = get_args()
input_tensors = [[] for _ in range(len(model))] input_tensors = [[] for _ in range(len(model))]
output_tensors = [[] for _ in range(len(model))] output_tensors = [[] for _ in range(len(model))]
forward_data_store = [] forward_data_store = []
if not forward_only: if not forward_only:
output_tensor_grads = [[] for _ in range(len(model))] output_tensor_grads = [[] for _ in range(len(model))]
pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size() pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
pipeline_parallel_rank = mpu.get_pipeline_model_parallel_rank() pipeline_parallel_rank = parallel_state.get_pipeline_model_parallel_rank()
if num_microbatches % pipeline_parallel_size != 0:
msg = f'number of microbatches ({num_microbatches}) is not divisible by '
msg += f'pipeline-model-parallel-size ({pipeline_parallel_size}) '
msg += 'when using interleaved schedule'
raise RuntimeError(msg)
model_type = get_model_type(model[0])
if model_type == ModelType.encoder_and_decoder:
raise RuntimeError("Interleaving is not supported with an encoder and decoder model.")
if decoder_seq_length is not None and decoder_seq_length != tensor_shape[0]:
raise RuntimeError("Interleaving is not supported with a different decoder sequence length.")
if sequence_parallel:
seq_length, batch_size, hidden = tensor_shape
tensor_shape = (
seq_length // parallel_state.get_tensor_model_parallel_world_size(),
batch_size,
hidden,
)
if args.sequence_parallel:
seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
else:
seq_length = args.seq_length
tensor_shape = (seq_length, args.micro_batch_size, args.hidden_size)
# Compute number of warmup and remaining microbatches. # Compute number of warmup and remaining microbatches.
num_model_chunks = len(model) num_model_chunks = len(model)
num_microbatches = get_num_microbatches() * num_model_chunks total_num_microbatches = num_microbatches * num_model_chunks
all_warmup_microbatches = False all_warmup_microbatches = False
if forward_only: if forward_only:
num_warmup_microbatches = num_microbatches num_warmup_microbatches = total_num_microbatches
else: else:
# Run all forward passes and then all backward passes if number of # Run all forward passes and then all backward passes if number of
# microbatches is just the number of pipeline stages. # microbatches is just the number of pipeline stages.
...@@ -293,8 +399,8 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, ...@@ -293,8 +399,8 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
# all workers, followed by more microbatches after depending on # all workers, followed by more microbatches after depending on
# stage ID (more forward passes for earlier stages, later stages can # stage ID (more forward passes for earlier stages, later stages can
# immediately start with 1F1B). # immediately start with 1F1B).
if get_num_microbatches() == pipeline_parallel_size: if num_microbatches == pipeline_parallel_size:
num_warmup_microbatches = num_microbatches num_warmup_microbatches = total_num_microbatches
all_warmup_microbatches = True all_warmup_microbatches = True
else: else:
num_warmup_microbatches = \ num_warmup_microbatches = \
...@@ -302,9 +408,9 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, ...@@ -302,9 +408,9 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
num_warmup_microbatches += ( num_warmup_microbatches += (
num_model_chunks - 1) * pipeline_parallel_size num_model_chunks - 1) * pipeline_parallel_size
num_warmup_microbatches = min(num_warmup_microbatches, num_warmup_microbatches = min(num_warmup_microbatches,
num_microbatches) total_num_microbatches)
num_microbatches_remaining = \ num_microbatches_remaining = \
num_microbatches - num_warmup_microbatches total_num_microbatches - num_warmup_microbatches
def get_model_chunk_id(microbatch_id, forward): def get_model_chunk_id(microbatch_id, forward):
"""Helper method to get the model chunk ID given the iteration number.""" """Helper method to get the model chunk ID given the iteration number."""
...@@ -319,10 +425,10 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, ...@@ -319,10 +425,10 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
(run set_virtual_pipeline_model_parallel_rank() before calling (run set_virtual_pipeline_model_parallel_rank() before calling
forward_step()).""" forward_step())."""
model_chunk_id = get_model_chunk_id(microbatch_id, forward=True) model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id) parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
# forward step # forward step
if mpu.is_pipeline_first_stage(): if parallel_state.is_pipeline_first_stage():
if len(input_tensors[model_chunk_id]) == \ if len(input_tensors[model_chunk_id]) == \
len(output_tensors[model_chunk_id]): len(output_tensors[model_chunk_id]):
input_tensors[model_chunk_id].append(None) input_tensors[model_chunk_id].append(None)
...@@ -330,7 +436,8 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, ...@@ -330,7 +436,8 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
output_tensor = forward_step(forward_step_func, output_tensor = forward_step(forward_step_func,
data_iterator[model_chunk_id], data_iterator[model_chunk_id],
model[model_chunk_id], model[model_chunk_id],
input_tensor, num_microbatches,
input_tensor,
forward_data_store, forward_data_store,
timers, timers,
collect_non_loss_data) collect_non_loss_data)
...@@ -348,41 +455,42 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, ...@@ -348,41 +455,42 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
(run set_virtual_pipeline_model_parallel_rank() before calling (run set_virtual_pipeline_model_parallel_rank() before calling
backward_step()).""" backward_step())."""
model_chunk_id = get_model_chunk_id(microbatch_id, forward=False) model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id) parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
if mpu.is_pipeline_last_stage(): if parallel_state.is_pipeline_last_stage():
if len(output_tensor_grads[model_chunk_id]) == 0: if len(output_tensor_grads[model_chunk_id]) == 0:
output_tensor_grads[model_chunk_id].append(None) output_tensor_grads[model_chunk_id].append(None)
input_tensor = input_tensors[model_chunk_id].pop(0) input_tensor = input_tensors[model_chunk_id].pop(0)
output_tensor = output_tensors[model_chunk_id].pop(0) output_tensor = output_tensors[model_chunk_id].pop(0)
output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0) output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
input_tensor_grad = \ input_tensor_grad = \
backward_step(optimizer, backward_step(grad_scaler,
input_tensor, input_tensor,
output_tensor, output_tensor,
output_tensor_grad, output_tensor_grad,
model_type,
timers) timers)
return input_tensor_grad return input_tensor_grad
# Run warmup forward passes. # Run warmup forward passes.
mpu.set_virtual_pipeline_model_parallel_rank(0) parallel_state.set_virtual_pipeline_model_parallel_rank(0)
input_tensors[0].append( input_tensors[0].append(
p2p_communication.recv_forward(tensor_shape, timers=timers)) p2p_communication.recv_forward(tensor_shape, dtype, timers=timers))
for k in range(num_warmup_microbatches): for k in range(num_warmup_microbatches):
output_tensor = forward_step_helper(k) output_tensor = forward_step_helper(k)
# Determine if tensor should be received from previous stage. # Determine if tensor should be received from previous stage.
next_forward_model_chunk_id = get_model_chunk_id(k+1, forward=True) next_forward_model_chunk_id = get_model_chunk_id(k+1, forward=True)
recv_prev = True recv_prev = True
if mpu.is_pipeline_first_stage(ignore_virtual=True): if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
if next_forward_model_chunk_id == 0: if next_forward_model_chunk_id == 0:
recv_prev = False recv_prev = False
if k == (num_microbatches - 1): if k == (total_num_microbatches - 1):
recv_prev = False recv_prev = False
# Don't send tensor downstream if on last stage. # Don't send tensor downstream if on last stage.
if mpu.is_pipeline_last_stage(): if parallel_state.is_pipeline_last_stage():
output_tensor = None output_tensor = None
# Send and receive tensors as appropriate (send tensors computed # Send and receive tensors as appropriate (send tensors computed
...@@ -391,20 +499,20 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, ...@@ -391,20 +499,20 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
not all_warmup_microbatches: not all_warmup_microbatches:
input_tensor_grad = None input_tensor_grad = None
recv_next = True recv_next = True
if mpu.is_pipeline_last_stage(ignore_virtual=True): if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
recv_next = False recv_next = False
input_tensor, output_tensor_grad = \ input_tensor, output_tensor_grad = \
p2p_communication.send_forward_backward_recv_forward_backward( p2p_communication.send_forward_backward_recv_forward_backward(
output_tensor, input_tensor_grad, output_tensor, input_tensor_grad,
recv_prev=recv_prev, recv_next=recv_next, recv_prev=recv_prev, recv_next=recv_next,
tensor_shape=tensor_shape, tensor_shape=tensor_shape, dtype=dtype,
timers=timers) timers=timers)
output_tensor_grads[num_model_chunks-1].append(output_tensor_grad) output_tensor_grads[num_model_chunks-1].append(output_tensor_grad)
else: else:
input_tensor = \ input_tensor = \
p2p_communication.send_forward_recv_forward( p2p_communication.send_forward_recv_forward(
output_tensor, recv_prev=recv_prev, output_tensor, recv_prev=recv_prev,
tensor_shape=tensor_shape, tensor_shape=tensor_shape, dtype=dtype,
timers=timers) timers=timers)
input_tensors[next_forward_model_chunk_id].append(input_tensor) input_tensors[next_forward_model_chunk_id].append(input_tensor)
deallocate_output_tensor(output_tensor) deallocate_output_tensor(output_tensor)
...@@ -425,19 +533,19 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, ...@@ -425,19 +533,19 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
# Determine if current stage has anything to send in either direction, # Determine if current stage has anything to send in either direction,
# otherwise set tensor to None. # otherwise set tensor to None.
forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True) forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
mpu.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id) parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
if mpu.is_pipeline_last_stage(): if parallel_state.is_pipeline_last_stage():
output_tensor = None output_tensor = None
backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False) backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
mpu.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id) parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
if mpu.is_pipeline_first_stage(): if parallel_state.is_pipeline_first_stage():
input_tensor_grad = None input_tensor_grad = None
# Determine if peers are sending, and where in data structure to put # Determine if peers are sending, and where in data structure to put
# received tensors. # received tensors.
recv_prev = True recv_prev = True
if mpu.is_pipeline_first_stage(ignore_virtual=True): if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
# First stage is ahead of last stage by (pipeline_parallel_size - 1). # First stage is ahead of last stage by (pipeline_parallel_size - 1).
next_forward_model_chunk_id = get_model_chunk_id( next_forward_model_chunk_id = get_model_chunk_id(
forward_k - (pipeline_parallel_size - 1), forward=True) forward_k - (pipeline_parallel_size - 1), forward=True)
...@@ -449,7 +557,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, ...@@ -449,7 +557,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
forward=True) forward=True)
recv_next = True recv_next = True
if mpu.is_pipeline_last_stage(ignore_virtual=True): if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
# Last stage is ahead of first stage by (pipeline_parallel_size - 1). # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
next_backward_model_chunk_id = get_model_chunk_id( next_backward_model_chunk_id = get_model_chunk_id(
backward_k - (pipeline_parallel_size - 1), forward=False) backward_k - (pipeline_parallel_size - 1), forward=False)
...@@ -470,7 +578,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, ...@@ -470,7 +578,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
p2p_communication.send_forward_backward_recv_forward_backward( p2p_communication.send_forward_backward_recv_forward_backward(
output_tensor, input_tensor_grad, output_tensor, input_tensor_grad,
recv_prev=recv_prev, recv_next=recv_next, recv_prev=recv_prev, recv_next=recv_next,
tensor_shape=tensor_shape, timers=timers) tensor_shape=tensor_shape, dtype=dtype, timers=timers)
deallocate_output_tensor(output_tensor) deallocate_output_tensor(output_tensor)
# Put input_tensor and output_tensor_grad in data structures in the # Put input_tensor and output_tensor_grad in data structures in the
...@@ -486,25 +594,29 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, ...@@ -486,25 +594,29 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
if all_warmup_microbatches: if all_warmup_microbatches:
output_tensor_grads[num_model_chunks-1].append( output_tensor_grads[num_model_chunks-1].append(
p2p_communication.recv_backward(tensor_shape, timers=timers)) p2p_communication.recv_backward(tensor_shape, timers=timers))
for k in range(num_microbatches_remaining, num_microbatches): for k in range(num_microbatches_remaining, total_num_microbatches):
input_tensor_grad = backward_step_helper(k) input_tensor_grad = backward_step_helper(k)
next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False) next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
recv_next = True recv_next = True
if mpu.is_pipeline_last_stage(ignore_virtual=True): if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
if next_backward_model_chunk_id == (num_model_chunks - 1): if next_backward_model_chunk_id == (num_model_chunks - 1):
recv_next = False recv_next = False
if k == (num_microbatches - 1): if k == (total_num_microbatches - 1):
recv_next = False recv_next = False
output_tensor_grads[next_backward_model_chunk_id].append( output_tensor_grads[next_backward_model_chunk_id].append(
p2p_communication.send_backward_recv_backward( p2p_communication.send_backward_recv_backward(
input_tensor_grad, recv_next=recv_next, input_tensor_grad, recv_next=recv_next,
tensor_shape=tensor_shape, tensor_shape=tensor_shape, dtype=dtype,
timers=timers)) timers=timers))
return forward_data_store return forward_data_store
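The chunk bookkeeping in the interleaved schedule above hinges on how a microbatch index maps to a model chunk. The standalone sketch below shows the mapping this code is assumed to rely on; get_model_chunk_id itself lies outside this hunk, and the closure variables pipeline_parallel_size / num_model_chunks are passed explicitly here for illustration. The (pipeline_parallel_size - 1) offsets used when peeking at the next forward/backward chunk ids reflect the comments above: the first and last stages run that many microbatches apart.
def model_chunk_id(microbatch_id: int, forward: bool,
                   pipeline_parallel_size: int, num_model_chunks: int) -> int:
    # Microbatches are issued in groups of pipeline_parallel_size per chunk,
    # cycling through the chunks; backward walks the chunks in reverse.
    microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
    chunk = microbatch_id_in_group // pipeline_parallel_size
    if not forward:
        chunk = num_model_chunks - chunk - 1
    return chunk
# With 4 pipeline stages and 2 chunks per stage, forward microbatches
# 0-3 run on chunk 0, 4-7 on chunk 1, and the pattern then repeats.
assert [model_chunk_id(m, True, 4, 2) for m in range(8)] == [0] * 4 + [1] * 4
assert model_chunk_id(0, False, 4, 2) == 1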
def get_tensor_shapes(*,
def get_tensor_shapes(rank, model_type): rank: int,
model_type: ModelType,
tensor_shape: Shape,
decoder_seq_length: int,
sequence_parallel: bool):
# Determine right tensor sizes (based on position of rank with respect to split # Determine right tensor sizes (based on position of rank with respect to split
# rank) and model size. # rank) and model size.
# Send two tensors if model is T5 and rank is in decoder stage: # Send two tensors if model is T5 and rank is in decoder stage:
...@@ -513,48 +625,50 @@ def get_tensor_shapes(rank, model_type): ...@@ -513,48 +625,50 @@ def get_tensor_shapes(rank, model_type):
# If model is T5 and rank is at the boundary: # If model is T5 and rank is at the boundary:
# send one tensor (post-transpose from encoder). # send one tensor (post-transpose from encoder).
# Otherwise, send one tensor (pre-transpose). # Otherwise, send one tensor (pre-transpose).
args = get_args()
tensor_shapes = [] tensor_shapes = []
if args.sequence_parallel: assert (
seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size() len(tensor_shape) == 3
else: ), f"`tensor_shape` should be [sequence_length, micro_batch_size, hidden_size] but {tensor_shape}"
seq_length = args.seq_length
seq_length, micro_batch_size, hidden_size = tensor_shape
if sequence_parallel:
seq_length = seq_length // parallel_state.get_tensor_model_parallel_world_size()
if model_type == ModelType.encoder_and_decoder: if model_type == ModelType.encoder_and_decoder:
if args.sequence_parallel: if sequence_parallel:
decoder_seq_length = args.decoder_seq_length // mpu.get_tensor_model_parallel_world_size() decoder_seq_length = decoder_seq_length // parallel_state.get_tensor_model_parallel_world_size()
else:
decoder_seq_length = args.decoder_seq_length
if mpu.is_pipeline_stage_before_split(rank): if parallel_state.is_pipeline_stage_before_split(rank):
tensor_shapes.append((seq_length, args.micro_batch_size, args.hidden_size)) tensor_shapes.append((seq_length, micro_batch_size, hidden_size))
else: else:
tensor_shapes.append((decoder_seq_length, args.micro_batch_size, args.hidden_size)) tensor_shapes.append((decoder_seq_length, micro_batch_size, hidden_size))
tensor_shapes.append((seq_length, args.micro_batch_size, args.hidden_size)) tensor_shapes.append((seq_length, micro_batch_size, hidden_size))
else: else:
tensor_shapes.append((seq_length, args.micro_batch_size, args.hidden_size)) tensor_shapes.append((seq_length, micro_batch_size, hidden_size))
return tensor_shapes return tensor_shapes
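As a quick illustration of the shape logic above, here is a minimal standalone sketch; it does not import megatron, and the helper name expected_tensor_shapes is hypothetical. It mirrors how decoder-side ranks of an encoder_and_decoder model exchange two tensors while all other ranks exchange one.
from typing import List, Optional, Tuple
def expected_tensor_shapes(tensor_shape: Tuple[int, int, int],
                           encoder_and_decoder: bool,
                           rank_is_before_split: bool,
                           decoder_seq_length: Optional[int],
                           sequence_parallel: bool,
                           tp_world_size: int) -> List[Tuple[int, int, int]]:
    seq_length, micro_batch_size, hidden_size = tensor_shape
    if sequence_parallel:
        # Sequence parallelism shards the sequence dimension across the
        # tensor-parallel group.
        seq_length //= tp_world_size
        if decoder_seq_length is not None:
            decoder_seq_length //= tp_world_size
    if encoder_and_decoder and not rank_is_before_split:
        # Decoder-side ranks exchange the decoder activation plus the
        # encoder output.
        return [(decoder_seq_length, micro_batch_size, hidden_size),
                (seq_length, micro_batch_size, hidden_size)]
    return [(seq_length, micro_batch_size, hidden_size)]
# A decoder-side rank of a T5-style model: seq 512, decoder seq 128,
# micro batch 4, hidden 1024, no sequence parallelism.
assert expected_tensor_shapes((512, 4, 1024), True, False, 128, False, 1) == \
    [(128, 4, 1024), (512, 4, 1024)]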
def recv_forward(tensor_shapes, timers):
def recv_forward(tensor_shapes, dtype, timers):
input_tensors = [] input_tensors = []
for tensor_shape in tensor_shapes: for tensor_shape in tensor_shapes:
if tensor_shape is None: if tensor_shape is None:
input_tensors.append(None) input_tensors.append(None)
else: else:
input_tensors.append(p2p_communication.recv_forward(tensor_shape, input_tensors.append(p2p_communication.recv_forward(tensor_shape, dtype,
timers=timers)) timers=timers))
return input_tensors return input_tensors
def recv_backward(tensor_shapes, timers): def recv_backward(tensor_shapes, dtype, timers):
output_tensor_grads = [] output_tensor_grads = []
for tensor_shape in tensor_shapes: for tensor_shape in tensor_shapes:
if tensor_shape is None: if tensor_shape is None:
output_tensor_grads.append(None) output_tensor_grads.append(None)
else: else:
output_tensor_grads.append(p2p_communication.recv_backward(tensor_shape, output_tensor_grads.append(p2p_communication.recv_backward(tensor_shape, dtype,
timers=timers)) timers=timers))
return output_tensor_grads return output_tensor_grads
...@@ -565,7 +679,7 @@ def send_forward(output_tensors, tensor_shapes, timers): ...@@ -565,7 +679,7 @@ def send_forward(output_tensors, tensor_shapes, timers):
for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes): for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes):
if tensor_shape is None: if tensor_shape is None:
continue continue
p2p_communication.send_forward(output_tensor, tensor_shape, timers=timers) p2p_communication.send_forward(output_tensor, timers=timers)
def send_backward(input_tensor_grads, tensor_shapes, timers): def send_backward(input_tensor_grads, tensor_shapes, timers):
...@@ -574,10 +688,10 @@ def send_backward(input_tensor_grads, tensor_shapes, timers): ...@@ -574,10 +688,10 @@ def send_backward(input_tensor_grads, tensor_shapes, timers):
for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes): for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes):
if tensor_shape is None: if tensor_shape is None:
continue continue
p2p_communication.send_backward(input_tensor_grad, tensor_shape, timers=timers) p2p_communication.send_backward(input_tensor_grad, timers=timers)
def send_forward_recv_backward(output_tensors, tensor_shapes, timers): def send_forward_recv_backward(output_tensors, tensor_shapes, dtype, timers):
if not isinstance(output_tensors, list): if not isinstance(output_tensors, list):
output_tensors = [output_tensors] output_tensors = [output_tensors]
output_tensor_grads = [] output_tensor_grads = []
...@@ -586,12 +700,12 @@ def send_forward_recv_backward(output_tensors, tensor_shapes, timers): ...@@ -586,12 +700,12 @@ def send_forward_recv_backward(output_tensors, tensor_shapes, timers):
output_tensor_grads.append(None) output_tensor_grads.append(None)
continue continue
output_tensor_grad = p2p_communication.send_forward_recv_backward( output_tensor_grad = p2p_communication.send_forward_recv_backward(
output_tensor, tensor_shape, timers=timers) output_tensor, tensor_shape, dtype, timers=timers)
output_tensor_grads.append(output_tensor_grad) output_tensor_grads.append(output_tensor_grad)
return output_tensor_grads return output_tensor_grads
def send_backward_recv_forward(input_tensor_grads, tensor_shapes, timers): def send_backward_recv_forward(input_tensor_grads, tensor_shapes, dtype, timers):
if not isinstance(input_tensor_grads, list): if not isinstance(input_tensor_grads, list):
input_tensor_grads = [input_tensor_grads] input_tensor_grads = [input_tensor_grads]
input_tensors = [] input_tensors = []
...@@ -600,44 +714,55 @@ def send_backward_recv_forward(input_tensor_grads, tensor_shapes, timers): ...@@ -600,44 +714,55 @@ def send_backward_recv_forward(input_tensor_grads, tensor_shapes, timers):
input_tensors.append(None) input_tensors.append(None)
continue continue
input_tensor = p2p_communication.send_backward_recv_forward( input_tensor = p2p_communication.send_backward_recv_forward(
input_tensor_grad, tensor_shape, timers=timers) input_tensor_grad, tensor_shape, dtype, timers=timers)
input_tensors.append(input_tensor) input_tensors.append(input_tensor)
return input_tensors return input_tensors
def forward_backward_pipelining_without_interleaving(forward_step_func, def forward_backward_pipelining_without_interleaving(*,
forward_step_func,
data_iterator, data_iterator,
model, model: Union[torch.nn.Module, List[torch.nn.Module]],
optimizer, num_microbatches: int,
timers, dtype: torch.dtype,
forward_only, tensor_shape: Shape,
collect_non_loss_data=False): decoder_seq_length: Optional[int] = None,
grad_scaler: Callable = None,
sequence_parallel: bool = False,
forward_only: bool = False,
timers: Callable = None,
collect_non_loss_data: bool = False):
"""Run non-interleaved 1F1B schedule, with communication between pipeline """Run non-interleaved 1F1B schedule, with communication between pipeline
stages. stages.
Returns dictionary with losses if the last stage, empty dict otherwise.""" Returns dictionary with losses if the last stage, empty dict otherwise."""
args = get_args()
assert len(model) == 1 assert len(model) == 1
model = model[0] model = model[0]
# Compute number of warmup microbatches. # Compute number of warmup microbatches.
num_microbatches = get_num_microbatches()
num_warmup_microbatches = \ num_warmup_microbatches = \
(mpu.get_pipeline_model_parallel_world_size() - (parallel_state.get_pipeline_model_parallel_world_size() -
mpu.get_pipeline_model_parallel_rank() - 1) parallel_state.get_pipeline_model_parallel_rank() - 1)
num_warmup_microbatches = min( num_warmup_microbatches = min(
num_warmup_microbatches, num_warmup_microbatches,
num_microbatches) num_microbatches)
num_microbatches_remaining = \ num_microbatches_remaining = \
num_microbatches - num_warmup_microbatches num_microbatches - num_warmup_microbatches
unwrapped_model = unwrap_model( model_type = get_model_type(model)
model, (torchDDP, LocalDDP, Float16Module))
model_type = unwrapped_model.model_type rank = parallel_state.get_pipeline_model_parallel_rank()
rank = mpu.get_pipeline_model_parallel_rank() recv_tensor_shapes = get_tensor_shapes(rank=rank-1,
recv_tensor_shapes = get_tensor_shapes(rank-1, model_type) model_type=model_type,
send_tensor_shapes = get_tensor_shapes(rank, model_type) tensor_shape=tensor_shape,
decoder_seq_length=decoder_seq_length,
sequence_parallel=sequence_parallel)
send_tensor_shapes = get_tensor_shapes(rank=rank,
model_type=model_type,
tensor_shape=tensor_shape,
decoder_seq_length=decoder_seq_length,
sequence_parallel=sequence_parallel)
# Input, output tensors only need to be saved when doing backward passes # Input, output tensors only need to be saved when doing backward passes
input_tensors = None input_tensors = None
...@@ -649,10 +774,8 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, ...@@ -649,10 +774,8 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
# Run warmup forward passes. # Run warmup forward passes.
for i in range(num_warmup_microbatches): for i in range(num_warmup_microbatches):
input_tensor = recv_forward(recv_tensor_shapes, timers=timers) input_tensor = recv_forward(recv_tensor_shapes, dtype, timers=timers)
output_tensor = forward_step(forward_step_func, data_iterator, model, output_tensor = forward_step(forward_step_func, data_iterator, model, num_microbatches,
input_tensor, forward_data_store, input_tensor, forward_data_store,
timers, collect_non_loss_data) timers, collect_non_loss_data)
send_forward(output_tensor, send_tensor_shapes, timers=timers) send_forward(output_tensor, send_tensor_shapes, timers=timers)
if not forward_only: if not forward_only:
...@@ -664,25 +787,26 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, ...@@ -664,25 +787,26 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
# If all microbatches are run in warmup / cooldown phase, then no need to # If all microbatches are run in warmup / cooldown phase, then no need to
# receive this tensor here. # receive this tensor here.
if num_microbatches_remaining > 0: if num_microbatches_remaining > 0:
input_tensor = recv_forward(recv_tensor_shapes, timers=timers) input_tensor = recv_forward(recv_tensor_shapes, dtype, timers=timers)
# Run 1F1B in steady state. # Run 1F1B in steady state.
for i in range(num_microbatches_remaining): for i in range(num_microbatches_remaining):
last_iteration = (i == (num_microbatches_remaining - 1)) last_iteration = (i == (num_microbatches_remaining - 1))
output_tensor = forward_step(forward_step_func, data_iterator, model, output_tensor = forward_step(forward_step_func, data_iterator, model, num_microbatches,
input_tensor, forward_data_store, input_tensor, forward_data_store,
timers, collect_non_loss_data) timers, collect_non_loss_data)
if forward_only: if forward_only:
send_forward(output_tensor, send_tensor_shapes, timers=timers) send_forward(output_tensor, send_tensor_shapes, timers=timers)
if not last_iteration: if not last_iteration:
input_tensor = recv_forward(recv_tensor_shapes, timers=timers) input_tensor = recv_forward(recv_tensor_shapes, dtype, timers=timers)
else: else:
output_tensor_grad = \ output_tensor_grad = \
send_forward_recv_backward(output_tensor, send_forward_recv_backward(output_tensor,
send_tensor_shapes, send_tensor_shapes, dtype,
timers=timers) timers=timers)
# Add input_tensor and output_tensor to end of list. # Add input_tensor and output_tensor to end of list.
...@@ -696,8 +820,8 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, ...@@ -696,8 +820,8 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
output_tensor = output_tensors.pop(0) output_tensor = output_tensors.pop(0)
input_tensor_grad = \ input_tensor_grad = \
backward_step(optimizer, input_tensor, output_tensor, backward_step(grad_scaler, input_tensor, output_tensor,
output_tensor_grad, timers) output_tensor_grad, model_type, timers)
if last_iteration: if last_iteration:
input_tensor = None input_tensor = None
...@@ -705,7 +829,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, ...@@ -705,7 +829,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
else: else:
input_tensor = \ input_tensor = \
send_backward_recv_forward( send_backward_recv_forward(
input_tensor_grad, recv_tensor_shapes, timers=timers) input_tensor_grad, recv_tensor_shapes, dtype, timers=timers)
# Run cooldown backward passes. # Run cooldown backward passes.
if not forward_only: if not forward_only:
...@@ -713,11 +837,11 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, ...@@ -713,11 +837,11 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
input_tensor = input_tensors.pop(0) input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0) output_tensor = output_tensors.pop(0)
output_tensor_grad = recv_backward(send_tensor_shapes, timers=timers) output_tensor_grad = recv_backward(send_tensor_shapes, dtype, timers=timers)
input_tensor_grad = \ input_tensor_grad = \
backward_step(optimizer, input_tensor, output_tensor, backward_step(grad_scaler, input_tensor, output_tensor,
output_tensor_grad, timers) output_tensor_grad, model_type, timers)
send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers) send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers)
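The warmup bookkeeping above is easier to see with numbers. A minimal sketch of the arithmetic, using the hypothetical helper name warmup_and_steady:
def warmup_and_steady(pp_world_size: int, pp_rank: int, num_microbatches: int):
    # Each rank runs (world_size - rank - 1) forward-only warmup microbatches,
    # capped at the total number of microbatches; the rest run as 1F1B pairs.
    num_warmup = min(pp_world_size - pp_rank - 1, num_microbatches)
    return num_warmup, num_microbatches - num_warmup
# With 4 pipeline stages and 8 microbatches the first stage warms up with 3
# forwards before entering steady state, while the last stage starts 1F1B
# immediately.
assert warmup_and_steady(4, 0, 8) == (3, 5)
assert warmup_and_steady(4, 3, 8) == (0, 8)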
......
...@@ -13,6 +13,8 @@ import torch.nn.functional as F ...@@ -13,6 +13,8 @@ import torch.nn.functional as F
import torch.nn.init as init import torch.nn.init as init
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from torch.cuda.amp import custom_fwd, custom_bwd
from megatron.core.parallel_state import ( from megatron.core.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
...@@ -214,6 +216,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function): ...@@ -214,6 +216,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
"""See linear_with_grad_accumulation_and_async_allreduce""" """See linear_with_grad_accumulation_and_async_allreduce"""
@staticmethod @staticmethod
@custom_fwd
def forward(ctx, input, weight, bias, gradient_accumulation_fusion, def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
async_grad_allreduce, sequence_parallel): async_grad_allreduce, sequence_parallel):
ctx.save_for_backward(input, weight) ctx.save_for_backward(input, weight)
...@@ -243,6 +246,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function): ...@@ -243,6 +246,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
return output return output
@staticmethod @staticmethod
@custom_bwd
def backward(ctx, grad_output): def backward(ctx, grad_output):
input, weight = ctx.saved_tensors input, weight = ctx.saved_tensors
use_bias = ctx.use_bias use_bias = ctx.use_bias
...@@ -407,8 +411,8 @@ def linear_with_grad_accumulation_and_async_allreduce( ...@@ -407,8 +411,8 @@ def linear_with_grad_accumulation_and_async_allreduce(
"maximum speedup") "maximum speedup")
linear_with_grad_accumulation_and_async_allreduce.warned = True linear_with_grad_accumulation_and_async_allreduce.warned = True
with torch.cuda.amp.autocast(enabled=False): return LinearWithGradAccumulationAndAsyncCommunication.apply(*args)
return LinearWithGradAccumulationAndAsyncCommunication.apply(*args)
linear_with_grad_accumulation_and_async_allreduce.warned = False linear_with_grad_accumulation_and_async_allreduce.warned = False
class ColumnParallelLinear(torch.nn.Module): class ColumnParallelLinear(torch.nn.Module):
......
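The hunk above replaces the autocast(enabled=False) wrapper around .apply(*args) with the torch.cuda.amp custom_fwd/custom_bwd decorators on the autograd Function itself. A minimal toy Function showing the decorator pattern (ScaleByTwo is illustrative, not Megatron code); custom_bwd executes backward with the same autocast state that forward observed, which is presumably why the context manager at the call site could be dropped.
import torch
from torch.cuda.amp import custom_fwd, custom_bwd
class ScaleByTwo(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, x):
        return x * 2
    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        # Runs under the autocast state recorded by custom_fwd.
        return grad_output * 2
x = torch.randn(4, requires_grad=True)
ScaleByTwo.apply(x).sum().backward()
assert torch.allclose(x.grad, torch.full((4,), 2.0))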
...@@ -20,6 +20,21 @@ def divide(numerator, denominator): ...@@ -20,6 +20,21 @@ def divide(numerator, denominator):
ensure_divisibility(numerator, denominator) ensure_divisibility(numerator, denominator)
return numerator // denominator return numerator // denominator
def get_attr_wrapped_model(model, attr):
"""Get an attribute from a wrapped model"""
if isinstance(model, list):
raise RuntimeError("_get_attr_wrapped_model given a list of models")
while not hasattr(model, attr):
if not hasattr(model, "module"):
raise RuntimeError(f"_get_attr_wrapped_model couldn't find attribute {attr}")
model = model.module
return getattr(model, attr)
def get_model_type(model):
return get_attr_wrapped_model(model, 'model_type')
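A small usage sketch of the new helpers, assuming get_attr_wrapped_model from the hunk above is in scope; the toy wrapper classes below stand in for DDP/Float16Module-style wrappers and are illustrative only.
class _Inner:
    model_type = "encoder_or_decoder"
class _Wrapper:
    def __init__(self, module):
        self.module = module
wrapped = _Wrapper(_Wrapper(_Inner()))
# get_attr_wrapped_model walks .module until it finds the attribute,
# so model_type survives nested wrapping.
assert get_attr_wrapped_model(wrapped, "model_type") == "encoder_or_decoder"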
class GlobalMemoryBuffer: class GlobalMemoryBuffer:
"""Global buffer to avoid dynamic memory allocations. """Global buffer to avoid dynamic memory allocations.
......
...@@ -8,4 +8,3 @@ from .gpt_model import GPTModel ...@@ -8,4 +8,3 @@ from .gpt_model import GPTModel
from .t5_model import T5Model from .t5_model import T5Model
from .language_model import get_language_model from .language_model import get_language_model
from .module import Float16Module from .module import Float16Module
from .enums import ModelType
...@@ -2,10 +2,6 @@ ...@@ -2,10 +2,6 @@
import enum import enum
class ModelType(enum.Enum):
encoder_or_decoder = 1
encoder_and_decoder = 2
class LayerType(enum.Enum): class LayerType(enum.Enum):
encoder = 1 encoder = 1
decoder = 2 decoder = 2
......
...@@ -20,7 +20,8 @@ from megatron import get_args, get_retro_args, get_tensorboard_writer ...@@ -20,7 +20,8 @@ from megatron import get_args, get_retro_args, get_tensorboard_writer
from megatron.core import parallel_state from megatron.core import parallel_state
from megatron.core import tensor_parallel from megatron.core import tensor_parallel
from megatron.core import utils as core_utils from megatron.core import utils as core_utils
from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType from megatron.core.enums import ModelType
from megatron.model.enums import AttnMaskType, LayerType, AttnType
from megatron.model import LayerNorm from megatron.model import LayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl from megatron.model.fused_bias_gelu import bias_gelu_impl
......
...@@ -9,7 +9,8 @@ import torch.nn.functional as F ...@@ -9,7 +9,8 @@ import torch.nn.functional as F
from megatron import get_timers, get_args, core, get_num_microbatches from megatron import get_timers, get_args, core, get_num_microbatches
from .module import MegatronModule from .module import MegatronModule
from megatron.core import mpu, tensor_parallel from megatron.core import mpu, tensor_parallel
from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType from megatron.core.enums import ModelType
from megatron.model.enums import AttnMaskType, LayerType, AttnType
from megatron.model import LayerNorm from megatron.model import LayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl from megatron.model.fused_bias_gelu import bias_gelu_impl
......
...@@ -25,8 +25,8 @@ from megatron import print_rank_last ...@@ -25,8 +25,8 @@ from megatron import print_rank_last
from megatron.checkpointing import load_checkpoint from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint from megatron.checkpointing import save_checkpoint
from megatron.model import Float16Module from megatron.model import Float16Module
from megatron.model import ModelType
from megatron.model import GPTModel from megatron.model import GPTModel
from megatron.core.enums import ModelType
from megatron.optimizer import get_megatron_optimizer from megatron.optimizer import get_megatron_optimizer
from megatron.initialize import initialize_megatron from megatron.initialize import initialize_megatron
from megatron.initialize import write_args_to_tensorboard from megatron.initialize import write_args_to_tensorboard
...@@ -37,7 +37,7 @@ from megatron.utils import check_adlr_autoresume_termination ...@@ -37,7 +37,7 @@ from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import unwrap_model from megatron.utils import unwrap_model
from megatron.data.data_samplers import build_pretraining_data_loader from megatron.data.data_samplers import build_pretraining_data_loader
from megatron.utils import calc_params_l2_norm from megatron.utils import calc_params_l2_norm
from megatron.schedules import get_forward_backward_func from megatron.core.pipeline_parallel import get_forward_backward_func
from megatron.utils import report_memory from megatron.utils import report_memory
from megatron.model.vision.knn_monitor import compute_feature_bank from megatron.model.vision.knn_monitor import compute_feature_bank
...@@ -400,6 +400,7 @@ def setup_model_and_optimizer(model_provider_func, ...@@ -400,6 +400,7 @@ def setup_model_and_optimizer(model_provider_func,
return model, optimizer, opt_param_scheduler return model, optimizer, opt_param_scheduler
def train_step(forward_step_func, data_iterator, def train_step(forward_step_func, data_iterator,
model, optimizer, opt_param_scheduler): model, optimizer, opt_param_scheduler):
"""Single training step.""" """Single training step."""
...@@ -418,8 +419,16 @@ def train_step(forward_step_func, data_iterator, ...@@ -418,8 +419,16 @@ def train_step(forward_step_func, data_iterator,
forward_backward_func = get_forward_backward_func() forward_backward_func = get_forward_backward_func()
fwd_bwd_timers = timers if args.timing_log_level > 1 else None fwd_bwd_timers = timers if args.timing_log_level > 1 else None
losses_reduced = forward_backward_func( losses_reduced = forward_backward_func(
forward_step_func, data_iterator, model, forward_step_func=forward_step_func,
optimizer, fwd_bwd_timers, forward_only=False) data_iterator=data_iterator,
model=model,
num_microbatches=get_num_microbatches(),
dtype=args.params_dtype,
tensor_shape=(args.seq_length, args.micro_batch_size, args.hidden_size),
grad_scaler=optimizer.scale_loss,
sequence_parallel=args.sequence_parallel,
forward_only=False,
timers=fwd_bwd_timers)
timers('forward-backward').stop() timers('forward-backward').stop()
# Empty unused memory. # Empty unused memory.
...@@ -794,8 +803,15 @@ def evaluate(forward_step_func, ...@@ -794,8 +803,15 @@ def evaluate(forward_step_func,
forward_backward_func = get_forward_backward_func() forward_backward_func = get_forward_backward_func()
loss_dicts = forward_backward_func( loss_dicts = forward_backward_func(
forward_step_func, data_iterator, model, optimizer=None, forward_step_func=forward_step_func,
timers=None, forward_only=True) data_iterator=data_iterator,
model=model,
num_microbatches=get_num_microbatches(),
dtype=args.params_dtype,
tensor_shape=(args.seq_length, args.micro_batch_size, args.hidden_size),
sequence_parallel=args.sequence_parallel,
forward_only=True,
timers=None)
# Empty unused memory # Empty unused memory
if args.empty_unused_memory_level >= 1: if args.empty_unused_memory_level >= 1:
......
...@@ -11,8 +11,9 @@ from megatron import get_args ...@@ -11,8 +11,9 @@ from megatron import get_args
from megatron import print_rank_0 from megatron import print_rank_0
from megatron import get_timers from megatron import get_timers
from megatron.core import tensor_parallel from megatron.core import tensor_parallel
from megatron.core.enums import ModelType
from megatron.data.dataset_utils import build_train_valid_test_datasets from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import BertModel, ModelType from megatron.model import BertModel
from megatron.training import pretrain from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group from megatron.utils import average_losses_across_data_parallel_group
......
...@@ -9,8 +9,9 @@ from megatron import print_rank_0 ...@@ -9,8 +9,9 @@ from megatron import print_rank_0
from megatron import get_timers from megatron import get_timers
from megatron import get_tokenizer from megatron import get_tokenizer
from megatron.core import tensor_parallel from megatron.core import tensor_parallel
from megatron.core.enums import ModelType
from megatron.data.gpt_dataset import build_train_valid_test_datasets from megatron.data.gpt_dataset import build_train_valid_test_datasets
from megatron.model import GPTModel, ModelType from megatron.model import GPTModel
from megatron.training import pretrain from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids from megatron.utils import get_ltor_masks_and_position_ids
from megatron.utils import average_losses_across_data_parallel_group from megatron.utils import average_losses_across_data_parallel_group
......
...@@ -13,9 +13,9 @@ from megatron import get_args ...@@ -13,9 +13,9 @@ from megatron import get_args
from megatron import print_rank_0 from megatron import print_rank_0
from megatron import get_timers from megatron import get_timers
from megatron.core import mpu from megatron.core import mpu
from megatron.core.enums import ModelType
from megatron.data.biencoder_dataset_utils import get_ict_batch from megatron.data.biencoder_dataset_utils import get_ict_batch
from megatron.data.dataset_utils import build_train_valid_test_datasets from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import ModelType
from megatron.model.biencoder_model import biencoder_model_provider from megatron.model.biencoder_model import biencoder_model_provider
from megatron.training import pretrain from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group from megatron.utils import average_losses_across_data_parallel_group
......
...@@ -10,7 +10,8 @@ from megatron import get_timers ...@@ -10,7 +10,8 @@ from megatron import get_timers
from megatron import get_tokenizer from megatron import get_tokenizer
from megatron import print_rank_0 from megatron import print_rank_0
from megatron.core import mpu, tensor_parallel from megatron.core import mpu, tensor_parallel
from megatron.model import GPTModel, ModelType from megatron.core.enums import ModelType
from megatron.model import GPTModel
from megatron.training import pretrain from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids from megatron.utils import get_ltor_masks_and_position_ids
from tools.retro.pretraining.retro_dataset import get_retro_datasets from tools.retro.pretraining.retro_dataset import get_retro_datasets
......
...@@ -12,8 +12,9 @@ from megatron import ( ...@@ -12,8 +12,9 @@ from megatron import (
print_rank_0 print_rank_0
) )
from megatron.core import tensor_parallel from megatron.core import tensor_parallel
from megatron.core.enums import ModelType
from megatron.data.dataset_utils import build_train_valid_test_datasets from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import T5Model, ModelType from megatron.model import T5Model
from megatron.training import pretrain from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group from megatron.utils import average_losses_across_data_parallel_group
......
...@@ -6,8 +6,8 @@ import torch ...@@ -6,8 +6,8 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from functools import partial from functools import partial
from megatron import get_args, get_timers, print_rank_0 from megatron import get_args, get_timers, print_rank_0
from megatron.core.enums import ModelType
from megatron.data.vit_dataset import build_train_valid_datasets from megatron.data.vit_dataset import build_train_valid_datasets
from megatron.model import ModelType
from megatron.model.vision.classification import VitClassificationModel from megatron.model.vision.classification import VitClassificationModel
from megatron.model.vision.classification import MitClassificationModel from megatron.model.vision.classification import MitClassificationModel
from megatron.training import pretrain from megatron.training import pretrain
......
...@@ -7,6 +7,7 @@ import numpy as np ...@@ -7,6 +7,7 @@ import numpy as np
import torch.distributed as dist import torch.distributed as dist
from functools import partial from functools import partial
from megatron import get_args, get_timers, print_rank_0 from megatron import get_args, get_timers, print_rank_0
from megatron.core.enums import ModelType
from megatron.data.vit_dataset import build_train_valid_datasets from megatron.data.vit_dataset import build_train_valid_datasets
from megatron.model.vision.dino import DINOPretrainModel from megatron.model.vision.dino import DINOPretrainModel
from megatron.model.vision.knn_monitor import knn_predict, get_feature_bank from megatron.model.vision.knn_monitor import knn_predict, get_feature_bank
...@@ -15,7 +16,6 @@ from megatron.utils import average_losses_across_data_parallel_group, unwrap_mod ...@@ -15,7 +16,6 @@ from megatron.utils import average_losses_across_data_parallel_group, unwrap_mod
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module from megatron.model import Float16Module
from megatron.model import ModelType
def model_provider(pre_process=True, post_process=True): def model_provider(pre_process=True, post_process=True):
"""Build the model.""" """Build the model."""
......