Commit 79906517 authored by hubertlu-tw

Merge remote-tracking branch 'upstream/master' into IFU-master-2021-12-08

parents cc92a4b4 aa756cec
# coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
-from .fused_softmax import FusedScaleMaskSoftmax
+from apex.transformer.functional.fused_softmax import FusedScaleMaskSoftmax
__all__ = [
    "FusedScaleMaskSoftmax",
...
# coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,7 +15,7 @@
import torch
from apex._autocast_utils import _cast_if_autocast_enabled
-from ..enums import AttnMaskType
+from apex.transformer.enums import AttnMaskType
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
...
from typing import Optional
import logging
import os
import threading
def get_transformer_logger(name: str) -> logging.Logger:
name_wo_ext = os.path.splitext(name)[0]
return logging.getLogger(name_wo_ext)
def set_logging_level(verbosity) -> None:
"""Change logging severity.
Args:
verbosity: a standard ``logging`` level such as ``logging.INFO``.
"""
from apex import _library_root_logger
_library_root_logger.setLevel(verbosity)
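A minimal usage sketch of these helpers, assuming both live in apex.transformer.log_util (the schedules code further down imports get_transformer_logger from that path); the chosen level is illustrative:

import logging
from apex.transformer.log_util import get_transformer_logger, set_logging_level

set_logging_level(logging.INFO)             # raise the library root logger severity
logger = get_transformer_logger(__name__)   # any file extension is stripped from the name
logger.info("apex.transformer logging enabled")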
# coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,35 +15,41 @@
"""Megatron number of micro-batches calculators."""
from abc import ABC
from abc import abstractmethod
+from typing import Optional, List
-def build_num_microbatches_calculator(args):
+def build_num_microbatches_calculator(
+    rank: int,
+    rampup_batch_size: Optional[List[int]],
+    global_batch_size: int,
+    micro_batch_size: int,
+    data_parallel_size: int,
+):
    # Constant num micro-batches.
-    if args.rampup_batch_size is None:
+    if rampup_batch_size is None:
        num_microbatches_calculator = ConstantNumMicroBatches(
-            args.global_batch_size, args.micro_batch_size, args.data_parallel_size
+            global_batch_size, micro_batch_size, data_parallel_size
        )
-        if args.rank == 0:
+        if rank == 0:
            print(
                "setting number of micro-batches to constant {}".format(num_microbatches_calculator.get()), flush=True
            )
    else:
-        assert len(args.rampup_batch_size) == 3, (
+        assert len(rampup_batch_size) == 3, (
            "expected the following "
            "format: --rampup-batch-size <start batch size> "
            "<batch size incerement> <ramp-up samples>"
        )
-        start_batch_size = int(args.rampup_batch_size[0])
-        batch_size_increment = int(args.rampup_batch_size[1])
-        ramup_samples = int(args.rampup_batch_size[2])
-        if args.rank == 0:
+        start_batch_size = int(rampup_batch_size[0])
+        batch_size_increment = int(rampup_batch_size[1])
+        ramup_samples = int(rampup_batch_size[2])
+        if rank == 0:
            print(
                "will use batch size rampup starting from global batch "
                "size {} to global batch size {} with batch size increments "
                "{} over {} samples.".format(
-                    start_batch_size, args.global_batch_size, batch_size_increment, ramup_samples
+                    start_batch_size, global_batch_size, batch_size_increment, ramup_samples
                ),
                flush=True,
            )
@@ -51,9 +57,9 @@ def build_num_microbatches_calculator(args):
            start_batch_size,
            batch_size_increment,
            ramup_samples,
-            args.global_batch_size,
-            args.micro_batch_size,
-            args.data_parallel_size,
+            global_batch_size,
+            micro_batch_size,
+            data_parallel_size,
        )
    return num_microbatches_calculator
@@ -86,6 +92,8 @@ class ConstantNumMicroBatches(NumMicroBatchesCalculator):
        assert self.num_micro_batches >= 1
        self.current_global_batch_size = global_batch_size
+        self.micro_batch_size = micro_batch_size
    def update(self, consumed_samples, consistency_check):
        pass
...
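For reference, a hedged sketch of calling the new explicit-argument API (the module path apex.transformer.microbatches and the batch sizes are assumptions; the old API took a single `args` namespace):

from apex.transformer.microbatches import build_num_microbatches_calculator  # assumed path

calculator = build_num_microbatches_calculator(
    rank=0,                    # rank of this process
    rampup_batch_size=None,    # or [start_batch_size, batch_size_increment, ramup_samples]
    global_batch_size=512,
    micro_batch_size=4,
    data_parallel_size=16,
)
print(calculator.get())        # 512 // (4 * 16) == 8 micro-batches per step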
# coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,11 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model and data parallel groups."""
+from typing import Tuple
import torch
# TODO (mkozuki): Consider dissecting utils as this utils import is here
# only for ensure_divisibility
-from .tensor_parallel import utils
+from apex.transformer.utils import ensure_divisibility
# Intra-layer model parallel group that the current rank belongs to.
@@ -40,6 +42,9 @@ _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None
+# A list of ranks that have a copy of the embedding.
+_EMBEDDING_GLOBAL_RANKS = None
# A list of global ranks for each pipeline group to ease calculation of the source
# rank when broadcasting from the first or last pipeline stage
_PIPELINE_GLOBAL_RANKS = None
@@ -76,23 +81,26 @@ def initialize_model_parallel(
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
-    if torch.distributed.get_rank() == 0:
-        print("> initializing tensor model parallel with size {}".format(tensor_model_parallel_size_))
-        print("> initializing pipeline model parallel with size {}".format(pipeline_model_parallel_size_))
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size = torch.distributed.get_world_size()
    tensor_model_parallel_size = min(tensor_model_parallel_size_, world_size)
    pipeline_model_parallel_size = min(pipeline_model_parallel_size_, world_size)
-    # TODO (mkozuki): Consider moving `ensure_divisibility` to this file.
-    utils.ensure_divisibility(world_size, tensor_model_parallel_size * pipeline_model_parallel_size)
+    ensure_divisibility(world_size, tensor_model_parallel_size * pipeline_model_parallel_size)
    data_parallel_size = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size)
+    if torch.distributed.get_rank() == 0:
+        print("> initializing tensor model parallel with size {}".format(tensor_model_parallel_size))
+        print("> initializing pipeline model parallel with size {}".format(pipeline_model_parallel_size))
+        print("> initializing data parallel with size {}".format(data_parallel_size))
    num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
    num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
    num_data_parallel_groups = world_size // data_parallel_size
    if virtual_pipeline_model_parallel_size_ is not None:
+        assert pipeline_model_parallel_size_ > 2, \
+            'pipeline-model-parallel size should be greater than 2 with ' \
+            'interleaved schedule'
        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
        _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
@@ -138,6 +146,7 @@ def initialize_model_parallel(
    global _PIPELINE_GLOBAL_RANKS
    assert _PIPELINE_MODEL_PARALLEL_GROUP is None, "pipeline model parallel group is already initialized"
    global _EMBEDDING_GROUP
+    global _EMBEDDING_GLOBAL_RANKS
    assert _EMBEDDING_GROUP is None, "embedding group is already initialized"
    for i in range(num_pipeline_model_parallel_groups):
        ranks = range(i, world_size, num_pipeline_model_parallel_groups)
@@ -154,6 +163,19 @@ def initialize_model_parallel(
        group = torch.distributed.new_group(embedding_ranks)
        if rank in embedding_ranks:
            _EMBEDDING_GROUP = group
+        if rank in ranks:
+            _EMBEDDING_GLOBAL_RANKS = embedding_ranks
+def get_rank_info() -> Tuple[int, int, int]:
+    """Returns a tuple of (tensor, pipeline, data)-parallel-rank for logger."""
+    if model_parallel_is_initialized():
+        return (
+            get_tensor_model_parallel_rank(),
+            get_pipeline_model_parallel_rank(),
+            # get_virtual_pipeline_model_parallel_rank(),
+            get_data_parallel_rank(),
+        )
+    return (0, 0, 0)
def model_parallel_is_initialized():
@@ -193,6 +215,22 @@ def get_embedding_group():
    return _EMBEDDING_GROUP
+def is_rank_in_embedding_group(ignore_virtual=False):
+    """Return true if current rank is in embedding group, False otherwise."""
+    rank = torch.distributed.get_rank()
+    global _EMBEDDING_GLOBAL_RANKS
+    if ignore_virtual:
+        return rank in _EMBEDDING_GLOBAL_RANKS
+    if rank in _EMBEDDING_GLOBAL_RANKS:
+        if rank == _EMBEDDING_GLOBAL_RANKS[0]:
+            return is_pipeline_first_stage(ignore_virtual=False)
+        elif rank == _EMBEDDING_GLOBAL_RANKS[-1]:
+            return is_pipeline_last_stage(ignore_virtual=False)
+        else:
+            return True
+    return False
def set_tensor_model_parallel_world_size(world_size):
    """Set the tensor model parallel size"""
    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
@@ -344,3 +382,15 @@ def destroy_model_parallel():
    _DATA_PARALLEL_GROUP = None
    global _EMBEDDING_GROUP
    _EMBEDDING_GROUP = None
+    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
+    _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
+    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+    _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
+    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
+    _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
+    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+    _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
+    global _MPU_TENSOR_MODEL_PARALLEL_RANK
+    _MPU_TENSOR_MODEL_PARALLEL_RANK = None
+    global _MPU_PIPELINE_MODEL_PARALLEL_RANK
+    _MPU_PIPELINE_MODEL_PARALLEL_RANK = None
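A sketch of the typical lifecycle with the helpers added above (assumes a torchrun-style multi-GPU launch; the parallel sizes are illustrative):

import torch
from apex.transformer import parallel_state

torch.distributed.init_process_group(backend="nccl")
parallel_state.initialize_model_parallel(
    tensor_model_parallel_size_=2,
    pipeline_model_parallel_size_=4,
)
print(parallel_state.get_rank_info())               # (tensor, pipeline, data) ranks of this process
print(parallel_state.is_rank_in_embedding_group())  # True on ranks that hold a copy of the embedding
parallel_state.destroy_model_parallel()             # now also clears the cached ranks and world sizes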
from apex.transformer.pipeline_parallel.schedules import get_forward_backward_func
from apex.transformer.pipeline_parallel.schedules.common import build_model
__all__ = [
"get_forward_backward_func",
"build_model",
]
import time
import torch
class _Timer:
"""Timer."""
def __init__(self, name):
self.name_ = name
self.elapsed_ = 0.0
self.started_ = False
self.start_time = time.time()
def start(self):
"""Start the timer."""
assert not self.started_, "timer has already been started"
torch.cuda.synchronize()
self.start_time = time.time()
self.started_ = True
def stop(self):
"""Stop the timer."""
assert self.started_, "timer is not started"
torch.cuda.synchronize()
self.elapsed_ += time.time() - self.start_time
self.started_ = False
def reset(self):
"""Reset timer."""
self.elapsed_ = 0.0
self.started_ = False
def elapsed(self, reset=True):
"""Calculate the elapsed time."""
started_ = self.started_
# If timing is in progress, end it first.
if self.started_:
self.stop()
# Get the elapsed time.
elapsed_ = self.elapsed_
# Reset the elapsed time
if reset:
self.reset()
# If timing was in progress, set it back.
if started_:
self.start()
return elapsed_
class _Timers:
"""Group of timers."""
def __init__(self):
self.timers = {}
def __call__(self, name):
if name not in self.timers:
self.timers[name] = _Timer(name)
return self.timers[name]
def write(self, names, writer, iteration, normalizer=1.0, reset=False):
"""Write timers to a tensorboard writer"""
# currently when using add_scalars,
# torch.utils.add_scalars makes each timer its own run, which
# pollutes the runs list, so we just add each as a scalar
assert normalizer > 0.0
for name in names:
value = self.timers[name].elapsed(reset=reset) / normalizer
writer.add_scalar(name + "-time", value, iteration)
def log(self, names, normalizer=1.0, reset=True):
"""Log a group of timers."""
assert normalizer > 0.0
string = "time (ms)"
for name in names:
elapsed_time = self.timers[name].elapsed(reset=reset) * 1000.0 / normalizer
string += " | {}: {:.2f}".format(name, elapsed_time)
if torch.distributed.is_initialized():
if torch.distributed.get_rank() == (torch.distributed.get_world_size() - 1):
print(string, flush=True)
else:
print(string, flush=True)
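A short sketch of how this private timers helper is meant to be driven (requires a CUDA device, since start/stop synchronize; the timer name is arbitrary):

import time

timers = _Timers()
timers("forward-compute").start()
time.sleep(0.01)                   # stand-in for real work
timers("forward-compute").stop()
timers.log(["forward-compute"])    # prints: time (ms) | forward-compute: ...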
# coding=utf-8
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import reduce
import operator
from typing import Union, Optional, Tuple
import warnings
import torch
from apex._autocast_utils import _get_current_dtype
from apex.transformer import parallel_state
from apex.transformer.utils import split_tensor_into_1d_equal_chunks
from apex.transformer.utils import gather_split_1d_tensor
from apex.transformer.pipeline_parallel.utils import Shape
from apex.transformer.pipeline_parallel._timers import _Timers
def _run_p2pops(
tensor_send_prev: Union[torch.Tensor, None],
tensor_send_next: Union[torch.Tensor, None],
tensor_recv_prev: Union[torch.Tensor, None],
tensor_recv_next: Union[torch.Tensor, None],
):
ops = []
if tensor_send_prev is not None:
send_prev_op = torch.distributed.P2POp(
torch.distributed.isend,
tensor_send_prev,
parallel_state.get_pipeline_model_parallel_prev_rank(),
)
ops.append(send_prev_op)
if tensor_recv_prev is not None:
recv_prev_op = torch.distributed.P2POp(
torch.distributed.irecv,
tensor_recv_prev,
parallel_state.get_pipeline_model_parallel_prev_rank(),
)
ops.append(recv_prev_op)
if tensor_send_next is not None:
send_next_op = torch.distributed.P2POp(
torch.distributed.isend,
tensor_send_next,
parallel_state.get_pipeline_model_parallel_next_rank(),
)
ops.append(send_next_op)
if tensor_recv_next is not None:
recv_next_op = torch.distributed.P2POp(
torch.distributed.irecv,
tensor_recv_next,
parallel_state.get_pipeline_model_parallel_next_rank(),
)
ops.append(recv_next_op)
if len(ops) > 0:
reqs = torch.distributed.batch_isend_irecv(ops)
for req in reqs:
req.wait()
def _communicate(
tensor_send_next: Optional[torch.Tensor],
tensor_send_prev: Optional[torch.Tensor],
recv_prev: bool,
recv_next: bool,
tensor_shape: Optional[Shape] = None,
override_scatter_gather_tensors_in_pipeline: bool = False,
dtype_: torch.dtype = torch.float,
*,
scatter_gather_tensors_in_pipeline: bool = True,
params_dtype: Optional[torch.dtype] = None,
fp32_residual_connection: bool = False,
) -> Tuple[Union[torch.Tensor, None], Union[torch.Tensor, None]]:
"""Base function for communication of tensors between stages.
Args:
tensor_send_next: tensor to send to next rank (no tensor sent if set to None).
tensor_send_prev: tensor to send to prev rank (no tensor sent if set to None).
recv_prev: boolean for whether tensor should be received from previous rank.
recv_next: boolean for whether tensor should be received from next rank.
tensor_shape: optional, use when the input sequence contains fewer tokens than the default sequence length
override_scatter_gather_tensors_in_pipeline:
optional, when :obj:`True`, skip the scatter/gather communication optimization even if it is enabled globally
dtype_: the dtype of the tensors being communicated; used when `tensor_shape` is provided
Keyword args:
scatter_gather_tensors_in_pipeline: Optional. If :obj:`True`, use scatter/gather to optimize communication of tensors.
params_dtype: Optional and legacy. Defaults to torch.float. If you manually call `.half()` or `.bfloat16()` on
your model deliberately, pass this argument.
fp32_residual_connection: Optional. If :obj:`True`, move residual connections to fp32.
Returns:
tuple containing
- tensor_recv_prev: `torch.Tensor` if `recv_prev` is :obj:`True`, `None` otherwise.
- tensor_recv_next: `torch.Tensor` if `recv_next` is :obj:`True`, `None` otherwise.
"""
# Create placeholder tensors for receive in forward and backward directions if needed.
tensor_recv_prev = None
tensor_recv_next = None
if tensor_shape is None:
# In megatron, `tensor_shape` is set to `(args.seq_length, args.micro_batch_size, args.hidden_size)`
raise RuntimeError(
"`tensor_shape` must be specified. Common `tensor_shape` is `(seq_length, micro_batch_size, hidden_size)`")
if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
tensor_chunk_shape = (reduce(operator.mul, tensor_shape, 1) // parallel_state.get_tensor_model_parallel_world_size(),)
else:
tensor_chunk_shape = tensor_shape
# NOTE(mkozuki): In PyTorch AMP, i.e. `torch.cuda.amp.autocast` context, activation tensors can be either FP32,
# FP16, or BF16 and there's no way to tell the dtypes of tensors on different devices in general.
# It might be possible if we restrict model architecture.
# dtype = params_dtype or torch.float
# if fp32_residual_connection:
# dtype = torch.float
# if dtype_ is not None:
# dtype = dtype_
# requires_grad = False
dtype = dtype_  # make sure `dtype` is defined on every path; it is overridden to FP32 below when needed
if dtype_ != torch.float32 or params_dtype is not None:
if torch.distributed.get_rank() == 0:
warnings.warn("Tensor P2P communications are executed in FP32")
dtype = torch.float32
requires_grad = True
if recv_prev:
tensor_recv_prev = torch.empty(
tensor_chunk_shape,
requires_grad=requires_grad,
device=torch.cuda.current_device(),
dtype=dtype,
)
if recv_next:
tensor_recv_next = torch.empty(
tensor_chunk_shape,
requires_grad=requires_grad,
device=torch.cuda.current_device(),
dtype=dtype,
)
# Split tensor into smaller chunks if using scatter-gather optimization.
if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
if tensor_send_next is not None:
tensor_send_next = split_tensor_into_1d_equal_chunks(tensor_send_next)
if tensor_send_prev is not None:
tensor_send_prev = split_tensor_into_1d_equal_chunks(tensor_send_prev)
# Send tensors in both the forward and backward directions as appropriate.
_run_p2pops(tensor_send_prev, tensor_send_next, tensor_recv_prev, tensor_recv_next)
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()
# If using scatter-gather optimization, gather smaller chunks.
if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
if recv_prev:
tensor_recv_prev = (
gather_split_1d_tensor(tensor_recv_prev)
.view(tensor_shape)
.requires_grad_()
)
if recv_next:
tensor_recv_next = (
gather_split_1d_tensor(tensor_recv_next)
.view(tensor_shape)
.requires_grad_()
)
return tensor_recv_prev, tensor_recv_next
def recv_forward(
tensor_shape: Shape,
override_scatter_gather_tensors_in_pipeline: bool = False,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
) -> torch.Tensor:
"""Receive tensor from previous rank in pipeline (forward receive)."""
if parallel_state.is_pipeline_first_stage():
return None
if timers is not None:
timers("forward-recv").start()
input_tensor, _ = _communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_prev=True,
recv_next=False,
tensor_shape=tensor_shape,
override_scatter_gather_tensors_in_pipeline=override_scatter_gather_tensors_in_pipeline,
dtype_=_get_current_dtype(dtype),
)
if timers is not None:
timers("forward-recv").stop()
return input_tensor
def recv_backward(
tensor_shape: Shape = None,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
):
"""Receive tensor from next rank in pipeline (backward receive)."""
if parallel_state.is_pipeline_last_stage():
return None
if timers is not None:
timers("backward-recv").start()
_, output_tensor_grad = _communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_prev=False,
recv_next=True,
tensor_shape=tensor_shape,
dtype_=_get_current_dtype(dtype),
)
if timers is not None:
timers("backward-recv").stop()
return output_tensor_grad
def send_forward(
output_tensor: torch.Tensor,
override_scatter_gather_tensors_in_pipeline: bool = False,
tensor_shape: Shape = None,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
) -> None:
"""Send tensor to next rank in pipeline (forward send)."""
if parallel_state.is_pipeline_last_stage():
return
if timers is not None:
timers("forward-send").start()
_communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=False,
recv_next=False,
override_scatter_gather_tensors_in_pipeline=override_scatter_gather_tensors_in_pipeline,
tensor_shape=tensor_shape,
dtype_=_get_current_dtype(dtype),
)
if timers is not None:
timers("forward-send").stop()
def send_backward(
input_tensor_grad: torch.Tensor,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
) -> None:
"""Send tensor to previous rank in pipeline (backward send)."""
if parallel_state.is_pipeline_first_stage():
return
if timers is not None:
timers("backward-send").start()
_communicate(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=False,
recv_next=False,
tensor_shape=tensor_shape,
dtype_=_get_current_dtype(dtype),
)
if timers is not None:
timers("backward-send").stop()
def send_forward_recv_backward(
output_tensor: torch.Tensor,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
) -> Optional[torch.Tensor]:
"""Batched send and recv with next rank in pipeline."""
if parallel_state.is_pipeline_last_stage():
return None
if timers is not None:
timers("forward-send-backward-recv").start()
_, output_tensor_grad = _communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=False,
recv_next=True,
tensor_shape=tensor_shape,
dtype_=_get_current_dtype(dtype),
)
if timers is not None:
timers("forward-send-backward-recv").stop()
return output_tensor_grad
def send_backward_recv_forward(
input_tensor_grad: torch.Tensor,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
) -> torch.Tensor:
"""Batched send and recv with previous rank in pipeline."""
if parallel_state.is_pipeline_first_stage():
return None
if timers is not None:
timers("backward-send-forward-recv").start()
input_tensor, _ = _communicate(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=True,
recv_next=False,
tensor_shape=tensor_shape,
dtype_=_get_current_dtype(dtype),
)
if timers is not None:
timers("backward-send-forward-recv").stop()
return input_tensor
def send_forward_recv_forward(
output_tensor: torch.Tensor,
recv_prev: bool,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
) -> torch.Tensor:
"""Batched recv from previous rank and send to next rank in pipeline."""
if timers is not None:
timers("forward-send-forward-recv").start()
input_tensor, _ = _communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=recv_prev,
recv_next=False,
tensor_shape=tensor_shape,
dtype_=_get_current_dtype(dtype),
)
if timers is not None:
timers("forward-send-forward-recv").stop()
return input_tensor
def send_backward_recv_backward(
input_tensor_grad: torch.Tensor,
recv_next: bool,
tensor_shape: Shape,
*,
dtype: torch.dtype = torch.float,
timers: _Timers = None,
) -> torch.Tensor:
"""Batched recv from next rank and send to previous rank in pipeline."""
if timers is not None:
timers("backward-send-backward-recv").start()
_, output_tensor_grad = _communicate(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=False,
recv_next=recv_next,
tensor_shape=tensor_shape,
dtype_=_get_current_dtype(dtype),
)
if timers is not None:
timers("backward-send-backward-recv").stop()
return output_tensor_grad
def send_forward_backward_recv_forward_backward(
output_tensor: torch.Tensor,
input_tensor_grad: torch.Tensor,
recv_prev: bool,
recv_next: bool,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
):
"""Batched send and recv with previous and next ranks in pipeline."""
if timers is not None:
timers("forward-backward-send-forward-backward-recv").start()
input_tensor, output_tensor_grad = _communicate(
tensor_send_next=output_tensor,
tensor_send_prev=input_tensor_grad,
recv_prev=recv_prev,
recv_next=recv_next,
tensor_shape=tensor_shape,
dtype_=_get_current_dtype(dtype),
)
if timers is not None:
timers("forward-backward-send-forward-backward-recv").stop()
return input_tensor, output_tensor_grad
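A hedged sketch of how an intermediate pipeline stage would pair these helpers for one microbatch on an initialized pipeline-parallel job; `model` and the (seq_len, micro_batch, hidden) shape are placeholders:

tensor_shape = (2048, 4, 1024)                          # (seq_length, micro_batch_size, hidden_size)
input_tensor = recv_forward(tensor_shape=tensor_shape)  # None on the first stage
output_tensor = model(input_tensor)                     # hypothetical stage module
send_forward(output_tensor, tensor_shape=tensor_shape)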
import warnings
from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import forward_backward_no_pipelining
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import _forward_backward_pipelining_with_interleaving
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import (
forward_backward_pipelining_without_interleaving,
)
class ExperimentalWarning(Warning):
pass
def get_forward_backward_func(
virtual_pipeline_model_parallel_size, pipeline_model_parallel_size,
):
if parallel_state.get_pipeline_model_parallel_world_size() > 1:
if virtual_pipeline_model_parallel_size is not None:
if get_num_microbatches() % pipeline_model_parallel_size != 0:
msg = "number of microbatches is not divisible by pipeline-parallel size when using interleaved schedule"
raise RuntimeError(msg)
warnings.warn(
"Pipeline Model Parallel with interleaving scheduling is experimental. "
f"To use Pipeline Parallel without interleaving, set `virtual_pipeline_model_parallel_size` to `None`: {virtual_pipeline_model_parallel_size}",
ExperimentalWarning
)
forward_backward_func = _forward_backward_pipelining_with_interleaving
else:
forward_backward_func = forward_backward_pipelining_without_interleaving
else:
forward_backward_func = forward_backward_no_pipelining
return forward_backward_func
__all__ = [
"get_forward_backward_func",
]
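For illustration, resolving a schedule from the parallel configuration on an initialized job (the sizes are assumptions):

forward_backward_func = get_forward_backward_func(
    virtual_pipeline_model_parallel_size=None,  # no interleaving
    pipeline_model_parallel_size=4,
)
# With a pipeline world size > 1 and no virtual size, this resolves to
# forward_backward_pipelining_without_interleaving.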
# NOTE (mkozuki): For simplicity, tentatively `timers` related operations are commented out.
from typing import Any, Callable, Dict, List, Tuple, Union, Optional
import torch
from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.utils import listify_model
from apex.transformer.pipeline_parallel.utils import unwrap_model
from apex.transformer.tensor_parallel.layers import set_defaults_if_not_set_tensor_model_parallel_attributes
Batch = Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]
LossFunc = Callable[[torch.Tensor], torch.Tensor]
FwdStepFunc = Callable[[Batch, torch.nn.Module], Tuple[torch.Tensor, LossFunc]]
def build_model(
model_provider_func: Callable[[Any, Dict[str, Any]], torch.nn.Module],
wrap_with_ddp: bool = True,
virtual_pipeline_model_parallel_size: Optional[int] = None,
*args,
**kwargs
) -> List[torch.nn.Module]:
"""Build the model satisfying pipeline model parallel requirements.
This function sets `pre_process` and `post_process` to `**kwargs` and pass `*args` and `**kwargs` to
`model_provider_func`.
Args:
model_provider_func: A function which takes `*args` and `**kwargs` and returns a `nn.Module`.
wrap_with_ddp: If :obj:`True`, wrap the instantiated model
with `torch.nn.parallel.distributed.DistributedDataParallel`, a.k.a. `DDP`.
virtual_pipeline_model_parallel_size: Specify when using interleaving scheduling pipeline model parallel.
*args: arguments for model provider func
**kwargs: Keyword arguments for model provider func
Returns:
a list of `nn.Module`(s). If `virtual_pipeline_model_parallel_size` is not None,
the list has multiple models, otherwise one.
"""
if (
parallel_state.get_pipeline_model_parallel_world_size() > 1 and
virtual_pipeline_model_parallel_size is not None
):
model = []
for i in range(virtual_pipeline_model_parallel_size):
cur_args = args
cur_kwargs = kwargs
parallel_state.set_virtual_pipeline_model_parallel_rank(i)
# Set pre_process and post_process only after virtual rank is set.
pre_process = parallel_state.is_pipeline_first_stage()
post_process = parallel_state.is_pipeline_last_stage()
cur_kwargs.update({
"pre_process": pre_process,
"post_process": post_process,
})
this_model = model_provider_func(*cur_args, **cur_kwargs)
model.append(this_model)
else:
cur_args = args
cur_kwargs = kwargs
pre_process = parallel_state.is_pipeline_first_stage()
post_process = parallel_state.is_pipeline_last_stage()
cur_kwargs.update({
"pre_process": pre_process,
"post_process": post_process,
})
model = model_provider_func(*cur_args, **cur_kwargs)
if not isinstance(model, list):
model = [model]
# Set tensor model parallel attributes if not set.
# Only parameters that are already tensor model parallel have these
# attributes set for them. We should make sure the default attributes
# are set for all params so the optimizer can use them.
for model_module in model:
for param in model_module.parameters():
set_defaults_if_not_set_tensor_model_parallel_attributes(param)
# Print number of parameters.
if parallel_state.get_data_parallel_rank() == 0:
msg = " > number of parameters on (tensor, pipeline) model parallel rank ({}, {}): {}".format(
parallel_state.get_tensor_model_parallel_rank(),
parallel_state.get_pipeline_model_parallel_rank(),
sum([sum([p.nelement() for p in model_module.parameters()]) for model_module in model])
)
print(msg, flush=True)
# GPU allocation.
for model_module in model:
model_module.cuda(torch.cuda.current_device())
if wrap_with_ddp:
i = torch.cuda.current_device()
model = [
torch.nn.parallel.distributed.DistributedDataParallel(
model_module,
device_ids=[i],
output_device=i,
process_group=parallel_state.get_data_parallel_group(),
)
for model_module in model
]
return model
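A sketch of a model provider consuming the injected flags; `MyTransformerStage` is a placeholder for a user-defined module:

def model_provider_func(pre_process: bool, post_process: bool) -> torch.nn.Module:
    # pre_process/post_process tell a stage whether to build the embedding / output head.
    return MyTransformerStage(pre_process=pre_process, post_process=post_process)

model = build_model(model_provider_func, wrap_with_ddp=True)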
def _get_params_for_weight_decay_optimization(
model: Union[torch.nn.Module, List[torch.nn.Module]],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""Divide params into with-weight-decay and without-weight-decay groups.
Layernorms and biases will have no weight decay but the rest will.
"""
modules = listify_model(model)
from apex.normalization.fused_layer_norm import FusedLayerNorm # NOQA
weight_decay_params = {'params': []}
no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
for module in modules:
for module_ in module.modules():
if isinstance(module_, FusedLayerNorm):
no_weight_decay_params['params'].extend(
[p for p in list(module_._parameters.values())
if p is not None])
else:
weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and n != 'bias'])
no_weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and n == 'bias'])
return weight_decay_params, no_weight_decay_params
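The two returned groups plug straight into a torch optimizer; a minimal sketch (the optimizer choice and hyperparameters are illustrative):

weight_decay_params, no_weight_decay_params = _get_params_for_weight_decay_optimization(model)
optimizer = torch.optim.AdamW(
    [weight_decay_params, no_weight_decay_params],  # biases and LayerNorm params keep weight_decay=0.0
    lr=1e-4,
    weight_decay=0.01,
)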
def forward_step(
forward_step_func: FwdStepFunc,
batch: Batch,
model: torch.nn.Module,
input_tensor: Optional[torch.Tensor],
losses_reduced: List[torch.Tensor],
):
"""Forward step for passed-in model.
If first stage, input tensor is obtained from data_iterator, otherwise
passed-in input_tensor is used.
Returns output tensor.
Args:
forward_step_func: Model specific function. This takes a minibatch and model as its arguments and
returns the model's output and the loss function.
batch: minibatch
model: unwrappable model
input_tensor:
losses_reduced:
Returns:
output_tensor
"""
# timers = get_timers()
# timers("forward-compute").start()
unwrapped_model = unwrap_model(model)
# NOTE (mkozuki): The passed `model` is expected to implement `set_input_tensor`.
# See https://github.com/NVIDIA/Megatron-LM/blob/5ac5571ba0265af4c491ee0af1508ca7589450c6/megatron/model/transformer.py#L679 # NOQA
# for the details of `set_input_tensor`.
unwrapped_model.set_input_tensor(input_tensor)
output_tensor, loss_func = forward_step_func(batch, model)
# print(f"forward_step| pipeline rank: {parallel_state.get_pipeline_model_parallel_rank()} is_pipeline_last_stage?: {parallel_state.is_pipeline_last_stage()}")
if parallel_state.is_pipeline_last_stage():
output_tensor = loss_func(output_tensor)
loss, loss_reduced = output_tensor
output_tensor = loss / get_num_microbatches()
losses_reduced.append(loss_reduced)
# timers("forward-compute").stop()
return output_tensor
def backward_step(
input_tensor: Optional[torch.Tensor],
output_tensor: torch.Tensor,
output_tensor_grad: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
"""Backward step through passed-in output tensor.
If last stage, output_tensor_grad is None, otherwise gradient of loss
with respect to stage's output tensor.
Returns gradient of loss with respect to input tensor (None if first
stage).
Args:
input_tensor:
output_tensor:
output_tensor_grad:
Returns:
input_tensor_grad
"""
# timers = get_timers()
# timers("backward-compute").start()
# Retain the grad on the input_tensor.
# if parallel_state.get_pipeline_model_parallel_rank() == 0:
# print(f"{input_tensor}, {output_tensor}, {output_tensor_grad}")
if input_tensor is not None:
input_tensor.retain_grad()
# Backward pass.
# if output_tensor_grad is None:
# output_tensor = optimizer.scale_loss(output_tensor)
torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
input_tensor_grad = None
if input_tensor is not None:
input_tensor_grad = input_tensor.grad
# timers("backward-compute").stop()
return input_tensor_grad
from contextlib import contextmanager
from typing import List, Union
import torch
from apex.transformer.pipeline_parallel.utils import listify_model
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.utils import get_kth_microbatch
from apex.transformer.pipeline_parallel.schedules.common import Batch, FwdStepFunc
from apex.transformer.pipeline_parallel.schedules.common import forward_step
from apex.transformer.pipeline_parallel.schedules.common import backward_step
from apex.transformer.log_util import get_transformer_logger
__all__ = ["forward_backward_no_pipelining"]
_logger = get_transformer_logger(__name__)
@contextmanager
def placeholder_handler():
try:
yield
finally:
pass
def forward_backward_no_pipelining(
forward_step_func: FwdStepFunc,
batch: Batch,
model: Union[torch.nn.Module, List[torch.nn.Module]],
*,
forward_only: bool,
**kwargs,
):
"""Run forward and backward passes with no pipeline parallelism (no inter-stage communication).
This pipeline parallel scheduling handles the last microbatch differently to synchronize gradients.
Args:
forward_step_func: A function which takes a minibatch and model as its arguments and
returns model's forward output and the loss function.
The loss function is supposed to take one `torch.Tensor` and
return a `torch.Tensor` of loss and a dictionary of `str` and `torch.Tensor`.
batch: A List of torch.Tensors
model: A `torch.nn.Module` or a list of `torch.nn.Module`.
Keyword args:
forward_only:
**kwargs: Added to handle `tensor_shape` which has no effect on this function.
Returns:
a list of dictionaries of loss `torch.Tensor`s if the last stage, empty list otherwise.
"""
model = listify_model(model)
if len(model) != 1:
msg = f"`model` is expected to be a `nn.Module`, but got {type(model)}"
raise RuntimeError(msg)
model = model[0]
context_handler = placeholder_handler
if isinstance(model, torch.nn.parallel.distributed.DistributedDataParallel):
context_handler = model.no_sync
losses_reduced = []
input_tensor, output_tensor_grad = None, None
num_micro_batches = get_num_microbatches()
with context_handler():
for i in range(num_micro_batches - 1):
_logger.info(f"Iter {i} of {num_micro_batches - 1}")
cur_micro_batch = get_kth_microbatch(batch, i)
_logger.debug("Call `forward_step`")
output_tensor = forward_step(
forward_step_func, cur_micro_batch, model, input_tensor, losses_reduced)
if not forward_only:
_logger.debug("Call `backward_step`")
backward_step(input_tensor, output_tensor, output_tensor_grad)
# Run computation for last microbatch out of context handler (want to
# synchronize gradients).
_logger.info("Cooldown")
_logger.debug("Call `forward_step`")
output_tensor = forward_step(
forward_step_func, get_kth_microbatch(batch, num_micro_batches - 1), model, input_tensor, losses_reduced
)
if not forward_only:
_logger.debug("Call `backward_step`")
backward_step(input_tensor, output_tensor, output_tensor_grad)
return losses_reduced
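A hedged sketch of driving this schedule for one global batch; `forward_step_func`, `batch`, and `model` are user-supplied, as described in the docstring above:

losses = forward_backward_no_pipelining(
    forward_step_func,
    batch,
    model,
    forward_only=False,
)
# `losses` holds one reduced-loss dict per microbatch.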
This diff is collapsed.
# coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,11 +15,11 @@
"""Model parallel utility interface."""
-from .cross_entropy import vocab_parallel_cross_entropy
-from .data import broadcast_data
-from .layers import (
+from apex.transformer.tensor_parallel.cross_entropy import vocab_parallel_cross_entropy
+from apex.transformer.tensor_parallel.data import broadcast_data
+from apex.transformer.tensor_parallel.layers import (
    ColumnParallelLinear,
    RowParallelLinear,
    VocabParallelEmbedding,
@@ -28,7 +28,7 @@ from .layers import (
    copy_tensor_model_parallel_attributes,
)
-from .mappings import (
+from apex.transformer.tensor_parallel.mappings import (
    copy_to_tensor_model_parallel_region,
    gather_from_tensor_model_parallel_region,
    reduce_from_tensor_model_parallel_region,
@@ -41,11 +41,9 @@ from .random import (
    init_checkpointed_activations_memory_buffer,
    model_parallel_cuda_manual_seed,
    reset_checkpointed_activations_memory_buffer,
-    gather_split_1d_tensor,
-    split_tensor_into_1d_equal_chunks,
)
-from .utils import divide, split_tensor_along_last_dim
+from apex.transformer.tensor_parallel.utils import split_tensor_along_last_dim
__all__ = [
@@ -71,9 +69,6 @@ __all__ = [
    "init_checkpointed_activations_memory_buffer",
    "model_parallel_cuda_manual_seed",
    "reset_checkpointed_activations_memory_buffer",
-    "gather_split_1d_tensor",
-    "split_tensor_into_1d_equal_chunks",
    # utils.py
-    "divide",
    "split_tensor_along_last_dim",
]
This diff is collapsed.
This diff is collapsed.