Unverified Commit 96850dfa authored by Jithun Nair, committed by GitHub

Merge pull request #80 from ROCmSoftwarePlatform/IFU-master-2022-07-29

IFU-master-2022-07-29
parents 87fc4125 cc5f83b5
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# NOTE(mkozuki): This file defines LayerNorm variants that are compatible with Megatron-LM
# while avoiding the breaking change of adding a `sequence_parallel_enabled` attribute to
# apex.normalization.FusedLayerNorm and apex.contrib.layer_norm.FastLayerNorm.
import warnings
import torch
from apex.normalization import FusedLayerNorm as OrigFusedLayerNorm
from apex.normalization import MixedFusedLayerNorm as OrigMixedFusedLayerNorm
try:
from apex.contrib.layer_norm import FastLayerNorm as OrigFastLayerNorm
except ImportError:
HAS_FAST_LAYER_NORM = False
else:
HAS_FAST_LAYER_NORM = True
__all__ = [
"FusedLayerNorm",
"FastLayerNorm",
"MixedFusedLayerNorm",
]
def _set_sequence_parallel_enabled(
param: torch.Tensor,
sequence_parallel_enabled: bool,
) -> None:
setattr(param, "sequence_parallel_enabled", sequence_parallel_enabled)
class FusedLayerNorm(OrigFusedLayerNorm):
def __init__(
self,
normalized_shape,
eps: float = 1e-5,
elementwise_affine: bool = True,
*,
sequence_parallel_enabled: bool = False,
):
super().__init__(
normalized_shape=normalized_shape,
eps=eps,
elementwise_affine=elementwise_affine,
)
self.sequence_parallel_enabled = sequence_parallel_enabled
if self.elementwise_affine:
_set_sequence_parallel_enabled(self.weight, self.sequence_parallel_enabled)
_set_sequence_parallel_enabled(self.bias, self.sequence_parallel_enabled)
# note: MixedFusedLayerNorm is no different from FusedLayerNorm if it's used in `torch.cuda.amp`.
class MixedFusedLayerNorm(OrigMixedFusedLayerNorm):
def __init__(
self,
normalized_shape,
eps: float = 1e-5,
**kwargs,
) -> None:
self.sequence_parallel_enabled = kwargs.get("sequence_parallel_enabled", False)
super().__init__(normalized_shape=normalized_shape, eps=eps, **kwargs)
if self.sequence_parallel_enabled:
_set_sequence_parallel_enabled(self.weight, self.sequence_parallel_enabled)
_set_sequence_parallel_enabled(self.bias, self.sequence_parallel_enabled)
if HAS_FAST_LAYER_NORM:
class FastLayerNorm(OrigFastLayerNorm):
def __init__(
self,
hidden_size,
eps: float = 1e-5,
*,
sequence_parallel_enabled: bool = False,
):
super().__init__(
hidden_size=hidden_size,
eps=eps
)
self.sequence_parallel_enabled = sequence_parallel_enabled
_set_sequence_parallel_enabled(self.weight, self.sequence_parallel_enabled)
_set_sequence_parallel_enabled(self.bias, self.sequence_parallel_enabled)
else:
class FastLayerNorm(FusedLayerNorm):
def __init__(
self,
hidden_size,
eps: float = 1e-5,
*,
sequence_parallel_enabled: bool = False,
):
warnings.warn("`apex.contrib.layer_norm.FastLayerNorm` isn't available; falling back to `apex.normalization.FusedLayerNorm`")
super().__init__(
normalized_shape=hidden_size,
eps=eps,
elementwise_affine=True,
sequence_parallel_enabled=sequence_parallel_enabled,
)
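A minimal usage sketch of the wrappers above, assuming a CUDA build of apex; the import path and the tensor shape are illustrative. The only behavioral difference from the base classes is the `sequence_parallel_enabled` attribute stamped onto the affine parameters, which Megatron-style code inspects when reducing gradients.
import torch
from apex.transformer.layers import FusedLayerNorm  # import path assumed for illustration

hidden_size = 1024
ln = FusedLayerNorm(hidden_size, sequence_parallel_enabled=True).cuda()
assert getattr(ln.weight, "sequence_parallel_enabled", False)  # set by the wrapper above
assert getattr(ln.bias, "sequence_parallel_enabled", False)
y = ln(torch.randn(8, hidden_size, device="cuda"))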
from typing import Optional
import logging
import os
import threading
def get_transformer_logger(name: str) -> logging.Logger:
@@ -16,4 +14,5 @@ def set_logging_level(verbosity) -> None:
verbosity
"""
from apex import _library_root_logger
_library_root_logger.setLevel(verbosity)
@@ -17,13 +17,18 @@ from abc import ABC
from abc import abstractmethod
from typing import Optional, List
from apex.transformer.log_util import get_transformer_logger
_logger = get_transformer_logger(__name__)
def build_num_microbatches_calculator(
rank: int,
rampup_batch_size: Optional[List[int]],
global_batch_size: int,
micro_batch_size: int,
data_parallel_size: int,
rank: int,
rampup_batch_size: Optional[List[int]],
global_batch_size: int,
micro_batch_size: int,
data_parallel_size: int,
):
# Constant num micro-batches.
if rampup_batch_size is None:
@@ -31,8 +36,10 @@ def build_num_microbatches_calculator(
global_batch_size, micro_batch_size, data_parallel_size
)
if rank == 0:
print(
"setting number of micro-batches to constant {}".format(num_microbatches_calculator.get()), flush=True
_logger.info(
"setting number of micro-batches to constant {}".format(
num_microbatches_calculator.get()
)
)
else:
@@ -45,13 +52,15 @@ def build_num_microbatches_calculator(
batch_size_increment = int(rampup_batch_size[1])
ramup_samples = int(rampup_batch_size[2])
if rank == 0:
print(
_logger.info(
"will use batch size rampup starting from global batch "
"size {} to global batch size {} with batch size increments "
"{} over {} samples.".format(
start_batch_size, global_batch_size, batch_size_increment, ramup_samples
),
flush=True,
start_batch_size,
global_batch_size,
batch_size_increment,
ramup_samples,
)
)
num_microbatches_calculator = RampupBatchsizeNumMicroBatches(
start_batch_size,
@@ -86,7 +95,9 @@ class ConstantNumMicroBatches(NumMicroBatchesCalculator):
micro_batch_times_data_parallel = micro_batch_size * data_parallel_size
assert global_batch_size % micro_batch_times_data_parallel == 0, (
"global batch size ({}) is not divisible by micro batch size ({})"
" times data parallel size ({})".format(global_batch_size, micro_batch_size, data_parallel_size)
" times data parallel size ({})".format(
global_batch_size, micro_batch_size, data_parallel_size
)
)
self.num_micro_batches = global_batch_size // micro_batch_times_data_parallel
assert self.num_micro_batches >= 1
@@ -126,7 +137,9 @@ class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator):
self.micro_batch_size = micro_batch_size
self.data_parallel_size = data_parallel_size
self.micro_batch_times_data_parallel_size = self.micro_batch_size * self.data_parallel_size
self.micro_batch_times_data_parallel_size = (
self.micro_batch_size * self.data_parallel_size
)
assert self.micro_batch_times_data_parallel_size > 0
assert start_batch_size > 0
@@ -158,15 +171,25 @@ class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator):
self.current_global_batch_size = self.global_batch_size
else:
steps = int(consumed_samples / self.rampup_samples_per_increment)
self.current_global_batch_size = self.start_batch_size + steps * self.batch_size_increment
self.current_global_batch_size = (
self.start_batch_size + steps * self.batch_size_increment
)
assert self.current_global_batch_size <= self.global_batch_size
if consistency_check:
assert self.current_global_batch_size % self.micro_batch_times_data_parallel_size == 0, (
assert (
self.current_global_batch_size
% self.micro_batch_times_data_parallel_size
== 0
), (
"current global "
"batch size ({}) is not divisible by micro-batch-size ({}) times"
"data parallel size ({})".format(
self.current_global_batch_size, self.micro_batch_size, self.data_parallel_size
self.current_global_batch_size,
self.micro_batch_size,
self.data_parallel_size,
)
)
self.num_micro_batches = self.current_global_batch_size // self.micro_batch_times_data_parallel_size
self.num_micro_batches = (
self.current_global_batch_size // self.micro_batch_times_data_parallel_size
)
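The ramp-up logic above boils down to a handful of integer operations; a self-contained sketch with illustrative numbers (not taken from the diff):
# Global batch size ramps from 32 to 256 in increments of 32, spread over 1,000,000 samples.
start, target, increment, rampup_samples = 32, 256, 32, 1_000_000
micro_batch_size, data_parallel_size = 4, 2
samples_per_increment = rampup_samples / ((target - start) // increment)
consumed_samples = 300_000
steps = int(consumed_samples / samples_per_increment)                                      # 2
current_global_batch_size = min(start + steps * increment, target)                         # 96
num_micro_batches = current_global_batch_size // (micro_batch_size * data_parallel_size)   # 12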
@@ -12,14 +12,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO (mkozuki): Replace assert with RuntimeError.
# TODO (mkozuki): Sort the functions in the same order of megatron/mpu/initialize.py
"""Model and data parallel groups."""
from typing import Tuple
from typing import Tuple, Optional
import warnings
import torch
# TODO (mkozuki): Consider dissecting utils as this utils import is here
# only for ensure_divisibility
from apex.transformer.utils import ensure_divisibility
from apex.transformer.log_util import get_transformer_logger
_logger = get_transformer_logger(__name__)
# N.B. (mkozuki): Diff btwn Megatron-LM & apex parallel_state
# set(megatron_mpu_initialize_funcs) - set(apex.transformer.parallel_state) =
# {
# 'get_num_layers',
# }
# Intra-layer model parallel group that the current rank belongs to.
@@ -30,11 +40,17 @@ _PIPELINE_MODEL_PARALLEL_GROUP = None
_MODEL_PARALLEL_GROUP = None
# Embedding group.
_EMBEDDING_GROUP = None
# Position embedding group.
_POSITION_EMBEDDING_GROUP = None
# Relative position embedding group.
_ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP = None
_DECODER_RELATIVE_POSITION_EMBEDDING_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None
# These values enable us to change the mpu sizes on the fly.
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
@@ -45,6 +61,13 @@ _MPU_PIPELINE_MODEL_PARALLEL_RANK = None
# A list of ranks that have a copy of the embedding.
_EMBEDDING_GLOBAL_RANKS = None
# A list of ranks that have a copy of the position embedding.
_POSITION_EMBEDDING_GLOBAL_RANKS = None
# A list of ranks that have a copy of the relative position embedding.
_ENCODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS = None
_DECODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS = None
# A list of global ranks for each pipeline group to ease calculation of the source
# rank when broadcasting from the first or last pipeline stage
_PIPELINE_GLOBAL_RANKS = None
@@ -56,14 +79,31 @@ def is_unitialized():
def initialize_model_parallel(
tensor_model_parallel_size_=1, pipeline_model_parallel_size_=1, virtual_pipeline_model_parallel_size_=None
):
tensor_model_parallel_size_: int = 1,
pipeline_model_parallel_size_: int = 1,
virtual_pipeline_model_parallel_size_: Optional[int] = None,
pipeline_model_parallel_split_rank_: Optional[int] = None,
*,
default_backend: Optional[str] = None,
p2p_backend: Optional[str] = None,
) -> None:
"""
Initialize model data parallel groups.
Arguments:
tensor_model_parallel_size: number of GPUs used to parallelize model tensor.
pipeline_model_parallel_size: number of GPUs used to parallelize model pipeline.
virtual_pipeline_model_parallel_size: number of virtual stages (interleaved pipeline).
pipeline_model_parallel_split_rank: for models with both encoder and decoder, rank in pipeline with split point.
Keyword Arguments:
default_backend: Backend of process groups except for pipeline parallel ones.
If :obj:`None`, the backend specified in `torch.distributed.init_process_group` will be used.
p2p_backend: Backend of process groups for pipeline model parallel.
If :obj:`None`, the backend specified in `torch.distributed.init_process_group` will be used.
.. note::
`torch_ucc <https://github.com/facebookresearch/torch_ucc>`_ is
necessary for "ucc" backend.
Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
@@ -83,28 +123,61 @@ def initialize_model_parallel(
"""
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
world_size = torch.distributed.get_world_size()
tensor_model_parallel_size = min(tensor_model_parallel_size_, world_size)
pipeline_model_parallel_size = min(pipeline_model_parallel_size_, world_size)
ensure_divisibility(world_size, tensor_model_parallel_size * pipeline_model_parallel_size)
data_parallel_size = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size)
assert default_backend is None or default_backend in ("nccl", "ucc")
assert p2p_backend is None or p2p_backend in ("nccl", "ucc")
if "ucc" in (default_backend, p2p_backend):
check_torch_ucc_availability()
warnings.warn("`ucc` backend support is experimental", ExperimentalWarning)
if default_backend == "ucc":
warnings.warn("The UCC's functionality as `default_backend` is not well verified", ExperimentalWarning)
world_size: int = torch.distributed.get_world_size()
tensor_model_parallel_size: int = min(tensor_model_parallel_size_, world_size)
pipeline_model_parallel_size: int = min(pipeline_model_parallel_size_, world_size)
if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0:
raise RuntimeError(
f"`world_size` ({world_size}) is not divisible by tensor_model_parallel_size ({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size})"
)
data_parallel_size: int = world_size // (
tensor_model_parallel_size * pipeline_model_parallel_size
)
if torch.distributed.get_rank() == 0:
print("> initializing tensor model parallel with size {}".format(tensor_model_parallel_size))
print("> initializing pipeline model parallel with size {}".format(pipeline_model_parallel_size))
print("> initializing data parallel with size {}".format(data_parallel_size))
_logger.info(
"> initializing tensor model parallel with size {}".format(
tensor_model_parallel_size
)
)
_logger.info(
"> initializing pipeline model parallel with size {}".format(
pipeline_model_parallel_size
)
)
_logger.info(
"> initializing data parallel with size {}".format(data_parallel_size)
)
num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
num_data_parallel_groups = world_size // data_parallel_size
num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size
num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
num_data_parallel_groups: int = world_size // data_parallel_size
if virtual_pipeline_model_parallel_size_ is not None:
assert pipeline_model_parallel_size_ > 2, \
'pipeline-model-parallel size should be greater than 2 with ' \
'interleaved schedule'
# n.b. (eqy) This check was inherited from Megatron-LM, need to revisit
# the root cause as we do see numerical mismatches with 2 stages and
# the interleaved schedule
assert pipeline_model_parallel_size_ > 2, (
"pipeline-model-parallel size should be greater than 2 with "
"interleaved schedule"
)
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size_
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = (
virtual_pipeline_model_parallel_size_
)
if pipeline_model_parallel_split_rank_ is not None:
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank_
rank = torch.distributed.get_rank()
@@ -118,7 +191,7 @@ def initialize_model_parallel(
for j in range(tensor_model_parallel_size):
ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)
all_data_parallel_group_ranks.append(list(ranks))
group = torch.distributed.new_group(ranks)
group = torch.distributed.new_group(ranks, backend=default_backend)
if rank in ranks:
_DATA_PARALLEL_GROUP = group
@@ -126,17 +199,24 @@ def initialize_model_parallel(
global _MODEL_PARALLEL_GROUP
assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized"
for i in range(data_parallel_size):
ranks = [data_parallel_group_ranks[i] for data_parallel_group_ranks in all_data_parallel_group_ranks]
group = torch.distributed.new_group(ranks)
ranks = [
data_parallel_group_ranks[i]
for data_parallel_group_ranks in all_data_parallel_group_ranks
]
group = torch.distributed.new_group(ranks, backend=default_backend)
if rank in ranks:
_MODEL_PARALLEL_GROUP = group
# Build the tensor model-parallel groups.
global _TENSOR_MODEL_PARALLEL_GROUP
assert _TENSOR_MODEL_PARALLEL_GROUP is None, "tensor model parallel group is already initialized"
assert (
_TENSOR_MODEL_PARALLEL_GROUP is None
), "tensor model parallel group is already initialized"
for i in range(num_tensor_model_parallel_groups):
ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)
group = torch.distributed.new_group(ranks)
ranks = list(
range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)
)
group = torch.distributed.new_group(ranks, backend=default_backend)
if rank in ranks:
_TENSOR_MODEL_PARALLEL_GROUP = group
@@ -144,43 +224,111 @@ def initialize_model_parallel(
# (first and last rank in each pipeline model-parallel group).
global _PIPELINE_MODEL_PARALLEL_GROUP
global _PIPELINE_GLOBAL_RANKS
assert _PIPELINE_MODEL_PARALLEL_GROUP is None, "pipeline model parallel group is already initialized"
assert (
_PIPELINE_MODEL_PARALLEL_GROUP is None
), "pipeline model parallel group is already initialized"
global _EMBEDDING_GROUP
global _EMBEDDING_GLOBAL_RANKS
assert _EMBEDDING_GROUP is None, "embedding group is already initialized"
global _POSITION_EMBEDDING_GROUP
global _POSITION_EMBEDDING_GLOBAL_RANKS
assert (
_POSITION_EMBEDDING_GROUP is None
), "position embedding group is already initialized"
global _ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP
global _DECODER_RELATIVE_POSITION_EMBEDDING_GROUP
global _ENCODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS
global _DECODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS
assert _ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP is None or \
_DECODER_RELATIVE_POSITION_EMBEDDING_GROUP is None, \
'relative position embedding group is already initialized'
for i in range(num_pipeline_model_parallel_groups):
ranks = range(i, world_size, num_pipeline_model_parallel_groups)
group = torch.distributed.new_group(ranks)
group = torch.distributed.new_group(ranks, backend=p2p_backend)
if rank in ranks:
_PIPELINE_MODEL_PARALLEL_GROUP = group
_PIPELINE_GLOBAL_RANKS = ranks
# Setup embedding group (to exchange gradients between
# first and last stages).
encoder_relative_position_embedding_ranks = None
decoder_relative_position_embedding_ranks = None
if len(ranks) > 1:
embedding_ranks = [ranks[0], ranks[-1]]
position_embedding_ranks = [ranks[0]]
encoder_relative_position_embedding_ranks = [ranks[0]]
decoder_relative_position_embedding_ranks = [ranks[0]]
if pipeline_model_parallel_split_rank_ is not None:
encoder_relative_position_embedding_ranks = \
ranks[:pipeline_model_parallel_split_rank_]
decoder_relative_position_embedding_ranks = \
ranks[pipeline_model_parallel_split_rank_:]
if ranks[pipeline_model_parallel_split_rank_] not in embedding_ranks:
embedding_ranks = [
ranks[0],
ranks[pipeline_model_parallel_split_rank_],
ranks[-1],
]
if (
ranks[pipeline_model_parallel_split_rank_]
not in position_embedding_ranks
):
position_embedding_ranks = [
ranks[0],
ranks[pipeline_model_parallel_split_rank_],
]
else:
embedding_ranks = ranks
group = torch.distributed.new_group(embedding_ranks)
position_embedding_ranks = ranks
encoder_relative_position_embedding_ranks = ranks
decoder_relative_position_embedding_ranks = ranks
group = torch.distributed.new_group(embedding_ranks, backend=default_backend)
if rank in embedding_ranks:
_EMBEDDING_GROUP = group
if rank in ranks:
_EMBEDDING_GLOBAL_RANKS = embedding_ranks
group = torch.distributed.new_group(position_embedding_ranks, backend=default_backend)
if rank in position_embedding_ranks:
_POSITION_EMBEDDING_GROUP = group
if rank in ranks:
_POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks
if encoder_relative_position_embedding_ranks:
group = torch.distributed.new_group(encoder_relative_position_embedding_ranks)
if rank in encoder_relative_position_embedding_ranks:
_ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP = group
if rank in ranks:
_ENCODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS = \
encoder_relative_position_embedding_ranks
if decoder_relative_position_embedding_ranks:
group = torch.distributed.new_group(decoder_relative_position_embedding_ranks)
if rank in decoder_relative_position_embedding_ranks:
_DECODER_RELATIVE_POSITION_EMBEDDING_GROUP = group
if rank in ranks:
_DECODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS = \
decoder_relative_position_embedding_ranks
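A usage sketch of the expanded signature; the process-group setup, sizes, and split rank are illustrative, and "ucc" requires torch_ucc and is experimental per the warnings above.
import torch
from apex.transformer import parallel_state

torch.distributed.init_process_group(backend="nccl")  # e.g. under `torchrun --nproc_per_node=16`
torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count())
parallel_state.initialize_model_parallel(
    tensor_model_parallel_size_=2,
    pipeline_model_parallel_size_=4,
    pipeline_model_parallel_split_rank_=2,  # encoder on stages 0-1, decoder on stages 2-3
    default_backend="nccl",
    p2p_backend="nccl",  # or "ucc" (experimental)
)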
def get_rank_info() -> Tuple[int, int, int, int]:
"""Returns a tuple of (tensor, pipeline, data)-parallel-rank for logger."""
"""Returns a tuple of (data, tensor, pipeline, virtual pipeline)-parallel-rank for logger."""
if model_parallel_is_initialized():
return (
get_data_parallel_rank(),
get_tensor_model_parallel_rank(),
get_pipeline_model_parallel_rank(),
# get_virtual_pipeline_model_parallel_rank(),
get_data_parallel_rank(),
get_virtual_pipeline_model_parallel_rank(),
)
return (0, 0, 0)
return (0, 0, 0, 0)
def model_parallel_is_initialized():
"""Check if model and data parallel groups are initialized."""
if _TENSOR_MODEL_PARALLEL_GROUP is None or _PIPELINE_MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None:
if (
_TENSOR_MODEL_PARALLEL_GROUP is None
or _PIPELINE_MODEL_PARALLEL_GROUP is None
or _DATA_PARALLEL_GROUP is None
):
return False
return True
@@ -193,13 +341,17 @@ def get_model_parallel_group():
def get_tensor_model_parallel_group():
"""Get the tensor model parallel group the caller rank belongs to."""
assert _TENSOR_MODEL_PARALLEL_GROUP is not None, "intra_layer_model parallel group is not initialized"
assert (
_TENSOR_MODEL_PARALLEL_GROUP is not None
), "intra_layer_model parallel group is not initialized"
return _TENSOR_MODEL_PARALLEL_GROUP
def get_pipeline_model_parallel_group():
"""Get the pipeline model parallel group the caller rank belongs to."""
assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, "pipeline_model parallel group is not initialized"
assert (
_PIPELINE_MODEL_PARALLEL_GROUP is not None
), "pipeline_model parallel group is not initialized"
return _PIPELINE_MODEL_PARALLEL_GROUP
@@ -215,6 +367,25 @@ def get_embedding_group():
return _EMBEDDING_GROUP
def get_position_embedding_group():
"""Get the position embedding group the caller rank belongs to."""
assert (
_POSITION_EMBEDDING_GROUP is not None
), "position embedding group is not initialized"
return _POSITION_EMBEDDING_GROUP
def get_encoder_relative_position_embedding_group():
"""Get the encoder relative position embedding group the caller rank belongs to."""
assert _ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP is not None, \
'encoder relative position embedding group is not initialized'
return _ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP
def get_decoder_relative_position_embedding_group():
"""Get the decoder relative position embedding group the caller rank belongs to."""
assert _DECODER_RELATIVE_POSITION_EMBEDDING_GROUP is not None, \
'decoder relative position embedding group is not initialized'
return _DECODER_RELATIVE_POSITION_EMBEDDING_GROUP
def is_rank_in_embedding_group(ignore_virtual=False):
"""Return true if current rank is in embedding group, False otherwise."""
rank = torch.distributed.get_rank()
@@ -231,6 +402,64 @@ def is_rank_in_embedding_group(ignore_virtual=False):
return False
def is_rank_in_position_embedding_group():
"""Return whether the current rank is in position embedding group."""
rank = torch.distributed.get_rank()
global _POSITION_EMBEDDING_GLOBAL_RANKS
return rank in _POSITION_EMBEDDING_GLOBAL_RANKS
def is_rank_in_encoder_relative_position_embedding_group():
"""Return true if current rank is in encoder relative position embedding group, False otherwise."""
rank = torch.distributed.get_rank()
global _ENCODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS
return rank in _ENCODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS
def is_rank_in_decoder_relative_position_embedding_group():
"""Return true if current rank is in decoder relative position embedding group, False otherwise."""
rank = torch.distributed.get_rank()
global _DECODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS
return rank in _DECODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS
def is_pipeline_stage_before_split(rank=None):
"""Return True if pipeline stage executes encoder block for a model
with both encoder and decoder."""
if get_pipeline_model_parallel_world_size() == 1:
return True
if rank is None:
rank = get_pipeline_model_parallel_rank()
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
return True
if rank < _PIPELINE_MODEL_PARALLEL_SPLIT_RANK:
return True
return False
def is_pipeline_stage_after_split(rank=None):
"""Return True if pipeline stage executes decoder block for a model
with both encoder and decoder."""
if get_pipeline_model_parallel_world_size() == 1:
return True
if rank is None:
rank = get_pipeline_model_parallel_rank()
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
return True
if rank >= _PIPELINE_MODEL_PARALLEL_SPLIT_RANK:
return True
return False
def is_pipeline_stage_at_split():
"""Return true if pipeline stage executes decoder block and next
stage executes encoder block for a model with both encoder and
decoder."""
rank = get_pipeline_model_parallel_rank()
return is_pipeline_stage_before_split(rank) and is_pipeline_stage_after_split(
rank + 1
)
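The three predicates above reduce to rank comparisons against the split rank; a self-contained sketch with a hypothetical 4-stage pipeline split at rank 2:
split_rank, num_stages = 2, 4
before_split = lambda r: r < split_rank   # encoder stages: 0, 1
after_split = lambda r: r >= split_rank   # decoder stages: 2, 3
at_split = [r for r in range(num_stages) if before_split(r) and after_split(r + 1)]
assert at_split == [1]  # the last encoder stage; stage 2 starts the decoder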
def set_tensor_model_parallel_world_size(world_size):
"""Set the tensor model parallel size"""
global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
@@ -287,6 +516,21 @@ def get_pipeline_model_parallel_rank():
return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
# TODO (mkozuki): Add [`get_num_layers`](https://github.com/NVIDIA/Megatron-LM/blob/e156d2fea7fc5c98e645f7742eb86b643956d840/megatron/mpu/initialize.py#L321) here, maybe?
def get_pipeline_model_parallel_split_rank():
"""Return my rank for the pipeline model parallel split rank."""
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
return _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
def set_pipeline_model_parallel_split_rank(pipeline_model_parallel_split_rank: int):
"""Set my rank for the pipeline model parallel split rank."""
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank
def is_pipeline_first_stage(ignore_virtual=False):
"""Return True if in the first pipeline model-parallel stage, False otherwise."""
if not ignore_virtual:
@@ -301,12 +545,16 @@ def is_pipeline_first_stage(ignore_virtual=False):
def is_pipeline_last_stage(ignore_virtual=False):
"""Return True if in the last pipeline model-parallel stage, False otherwise."""
if not ignore_virtual:
virtual_pipeline_model_parallel_world_size = get_virtual_pipeline_model_parallel_world_size()
virtual_pipeline_model_parallel_world_size = (
get_virtual_pipeline_model_parallel_world_size()
)
if virtual_pipeline_model_parallel_world_size is not None and get_virtual_pipeline_model_parallel_rank() != (
virtual_pipeline_model_parallel_world_size - 1
):
return False
return get_pipeline_model_parallel_rank() == (get_pipeline_model_parallel_world_size() - 1)
return get_pipeline_model_parallel_rank() == (
get_pipeline_model_parallel_world_size() - 1
)
def get_virtual_pipeline_model_parallel_rank():
@@ -335,26 +583,42 @@ def get_tensor_model_parallel_src_rank():
return (global_rank // local_world_size) * local_world_size
def get_data_parallel_src_rank():
"""Calculate the global rank corresponding to the first local rank in the data parallel group."""
global_rank = torch.distributed.get_rank()
data_parallel_size: int = get_data_parallel_world_size()
num_data_parallel_groups = torch.distributed.get_world_size() // data_parallel_size
return global_rank % num_data_parallel_groups
def get_pipeline_model_parallel_first_rank():
assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
assert (
_PIPELINE_GLOBAL_RANKS is not None
), "Pipeline parallel group is not initialized"
return _PIPELINE_GLOBAL_RANKS[0]
def get_pipeline_model_parallel_last_rank():
assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
assert (
_PIPELINE_GLOBAL_RANKS is not None
), "Pipeline parallel group is not initialized"
last_rank_local = get_pipeline_model_parallel_world_size() - 1
return _PIPELINE_GLOBAL_RANKS[last_rank_local]
def get_pipeline_model_parallel_next_rank():
assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
assert (
_PIPELINE_GLOBAL_RANKS is not None
), "Pipeline parallel group is not initialized"
rank_in_pipeline = get_pipeline_model_parallel_rank()
world_size = get_pipeline_model_parallel_world_size()
return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]
def get_pipeline_model_parallel_prev_rank():
assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
assert (
_PIPELINE_GLOBAL_RANKS is not None
), "Pipeline parallel group is not initialized"
rank_in_pipeline = get_pipeline_model_parallel_rank()
world_size = get_pipeline_model_parallel_world_size()
return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]
@@ -370,6 +634,9 @@ def get_data_parallel_rank():
return torch.distributed.get_rank(group=get_data_parallel_group())
# note (mkozuki): `destroy_model_parallel` voids more global variables than Megatron-LM's.
# Otherwise the pipeline parallel forward_backward functions test hangs, possibly because
# the clean-up of the original is NOT enough.
def destroy_model_parallel():
"""Set the groups to none."""
global _MODEL_PARALLEL_GROUP
@@ -382,6 +649,12 @@ def destroy_model_parallel():
_DATA_PARALLEL_GROUP = None
global _EMBEDDING_GROUP
_EMBEDDING_GROUP = None
global _POSITION_EMBEDDING_GROUP
_POSITION_EMBEDDING_GROUP = None
global _ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP
_ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP = None
global _DECODER_RELATIVE_POSITION_EMBEDDING_GROUP
_DECODER_RELATIVE_POSITION_EMBEDDING_GROUP = None
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
@@ -394,3 +667,16 @@ def destroy_model_parallel():
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
global _MPU_PIPELINE_MODEL_PARALLEL_RANK
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None
# Used to warn when the UCC backend is specified.
class ExperimentalWarning(Warning): pass
def check_torch_ucc_availability() -> None:
try:
import torch_ucc # NOQA
except ImportError:
raise ImportError(
"UCC backend requires [torch_ucc](https://github.com/facebookresearch/torch_ucc) but not found"
)
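`ExperimentalWarning` is an ordinary `Warning` subclass, so callers opting into the experimental "ucc" path can surface or silence it with the standard warnings machinery; a brief sketch:
import warnings
from apex.transformer import parallel_state

warnings.simplefilter("always", parallel_state.ExperimentalWarning)  # or "ignore" to silence
parallel_state.check_torch_ucc_availability()  # raises ImportError unless torch_ucc is importable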
# coding=utf-8
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021-22, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,63 +12,108 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO(mkozuki): Consider removing `timers`.
from functools import reduce
import operator
from typing import Union, Optional, Tuple
import warnings
import torch
from apex._autocast_utils import _get_current_dtype
from apex.transformer import parallel_state
from apex.transformer.log_util import get_transformer_logger
from apex.transformer.utils import split_tensor_into_1d_equal_chunks
from apex.transformer.utils import gather_split_1d_tensor
from apex.transformer.pipeline_parallel.utils import Shape
from apex.transformer.pipeline_parallel._timers import _Timers
_logger = get_transformer_logger(__name__)
class FutureTensor:
def __init__(self, tensor: torch.Tensor, waitfunc):
self.tensor = tensor
self.waitfunc = waitfunc
def get(self):
if self.waitfunc is not None:
res = self.waitfunc()
if isinstance(res, torch.Tensor):
self.tensor = res
self.waitfunc = None
return self.tensor
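A self-contained sketch of the FutureTensor contract. In real use the waitfunc is `req.wait` from `batch_isend_irecv` or one of the gather callbacks defined further down; here a trivial callable stands in:
import torch

buf = torch.empty(4)

def waitfunc():
    # Stand-in for `req.wait()`; a waitfunc may also return a new tensor
    # (e.g. the gathered full tensor), which then replaces the stored one.
    return buf.fill_(1.0)

fut = FutureTensor(buf, waitfunc)
out = fut.get()  # runs waitfunc exactly once and caches the result
assert fut.get() is out and fut.waitfunc is None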
def _run_p2pops(
tensor_send_prev: Union[torch.Tensor, None],
tensor_send_next: Union[torch.Tensor, None],
tensor_recv_prev: Union[torch.Tensor, None],
tensor_recv_next: Union[torch.Tensor, None],
tensor_send_prev: Union[torch.Tensor, None],
tensor_send_next: Union[torch.Tensor, None],
tensor_recv_prev: Union[torch.Tensor, None],
tensor_recv_next: Union[torch.Tensor, None],
async_comm: bool = False
):
ops = []
p2p_group = parallel_state.get_pipeline_model_parallel_group()
default_group = parallel_state.get_model_parallel_group()
need_to_sync = p2p_group.name() != default_group.name()
if tensor_send_prev is not None:
send_prev_op = torch.distributed.P2POp(
torch.distributed.isend,
tensor_send_prev,
parallel_state.get_pipeline_model_parallel_prev_rank(),
op=torch.distributed.isend,
tensor=tensor_send_prev,
peer=parallel_state.get_pipeline_model_parallel_prev_rank(),
group=p2p_group,
)
ops.append(send_prev_op)
if tensor_recv_prev is not None:
recv_prev_op = torch.distributed.P2POp(
torch.distributed.irecv,
tensor_recv_prev,
parallel_state.get_pipeline_model_parallel_prev_rank(),
op=torch.distributed.irecv,
tensor=tensor_recv_prev,
peer=parallel_state.get_pipeline_model_parallel_prev_rank(),
group=p2p_group,
)
ops.append(recv_prev_op)
if tensor_send_next is not None:
send_next_op = torch.distributed.P2POp(
torch.distributed.isend,
tensor_send_next,
parallel_state.get_pipeline_model_parallel_next_rank(),
op=torch.distributed.isend,
tensor=tensor_send_next,
peer=parallel_state.get_pipeline_model_parallel_next_rank(),
group=p2p_group,
)
ops.append(send_next_op)
if tensor_recv_next is not None:
recv_next_op = torch.distributed.P2POp(
torch.distributed.irecv,
tensor_recv_next,
parallel_state.get_pipeline_model_parallel_next_rank(),
op=torch.distributed.irecv,
tensor=tensor_recv_next,
peer=parallel_state.get_pipeline_model_parallel_next_rank(),
group=p2p_group,
)
ops.append(recv_next_op)
if len(ops) > 0:
reqs = torch.distributed.batch_isend_irecv(ops)
for req in reqs:
req.wait()
if need_to_sync:
torch.cuda.synchronize()
reqs = torch.distributed.batch_isend_irecv(ops)
if async_comm:
assert len(reqs) == len(ops)
tensor_send_prev_req = None if tensor_send_prev is None else reqs.pop(0)
tensor_recv_prev_req = None if tensor_recv_prev is None else reqs.pop(0)
tensor_send_next_req = None if tensor_send_next is None else reqs.pop(0)
tensor_recv_next_req = None if tensor_recv_next is None else reqs.pop(0)
return (tensor_send_prev_req, tensor_recv_prev_req, tensor_send_next_req, tensor_recv_next_req)
else:
for req in reqs:
req.wait()
return (None, None, None, None)
return (None, None, None, None)
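For reference, the batched P2P pattern that `_run_p2pops` assembles looks like the following as a standalone ring exchange between neighboring ranks; it assumes a launcher such as torchrun, an initialized NCCL process group, and the CUDA device already set:
import torch
import torch.distributed as dist

rank, world = dist.get_rank(), dist.get_world_size()
send_buf = torch.full((4,), float(rank), device="cuda")
recv_buf = torch.empty(4, device="cuda")
ops = [
    dist.P2POp(dist.isend, send_buf, (rank + 1) % world),
    dist.P2POp(dist.irecv, recv_buf, (rank - 1) % world),
]
reqs = dist.batch_isend_irecv(ops)
for req in reqs:  # the synchronous (async_comm=False) path above does exactly this
    req.wait()
torch.cuda.synchronize()  # guard against the batch_isend_irecv race noted in _communicate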
# TODO(mkozuki): Check if it's possible to sunset `override_scatter_gather_tensors_in_pipeline`.
# TODO(mkozuki): Think about if it's possible to push some logic and arguments e.g.
# `scatter_gather_tensors_in_pipeline`, `sequence_parallel_enabled`, and
# `override_scatter_gather_tensors_in_pipeline` to the user of
# apex.transformer forward_backward functions.
def _communicate(
tensor_send_next: Optional[torch.Tensor],
tensor_send_prev: Optional[torch.Tensor],
@@ -76,14 +121,26 @@ def _communicate(
recv_next: bool,
tensor_shape: Optional[Shape] = None,
override_scatter_gather_tensors_in_pipeline: bool = False,
dtype_: torch.dtype = torch.float,
dtype_: Optional[torch.dtype] = None,
*,
scatter_gather_tensors_in_pipeline: bool = True,
params_dtype: Optional[torch.dtype] = None,
fp32_residual_connection: bool = False,
) -> Tuple[Union[torch.Tensor, None], Union[torch.Tensor, None]]:
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
) -> Tuple[Union[torch.Tensor, FutureTensor, None], Union[torch.Tensor, FutureTensor, None]]:
"""Base function for communication of tensors between stages.
.. note::
Reference https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/blob/cfd2e2160700b7f2c1bf35298ac14bc341f4c759/megatron/p2p_communication.py#L24-L159
dtype logic: If none of ``dtype_``, ``params_dtype``, ``fp32_residual_connection`` is specified,
torch.float32 is used.
See https://github.com/NVIDIA/Megatron-LM/blob/d41696840ed0a7edb7e0499eb82a48ae112d9bb3/megatron/arguments.py#L145-L159
for the details of arguments of ``dtype_``, ``params_dtype``, ``fp32_residual_connection``.
Args:
tensor_send_next: tensor to send to next rank (no tensor sent if set to None).
tensor_send_prev: tensor to send to prev rank (no tensor sent if set to None).
@@ -99,6 +156,9 @@ def _communicate(
params_dtype: Optional and legacy. Defaults to torch.float. If you manually call `.half()` or `.bfloat16()` on
your model deliberately, pass this argument.
fp32_residual_connection: Optional. If :obj:`True`, move residual connections to fp32.
sequence_parallel_enabled: Set to :obj:`True` if sequence parallel is enabled.
This argument is here for consistency with Megatron-LM.
This argument has an effect on the communication optimization, not on tensor_shape update.
Returns:
tuple containing
@@ -106,6 +166,13 @@ def _communicate(
- tensor_recv_prev: `torch.Tensor` if `recv_prev` is :obj:`True`, `None` otherwise.
- tensor_recv_next: `torch.Tensor` if `recv_next` is :obj:`True`, `None` otherwise.
"""
if async_comm and sequence_parallel_enabled:
import warnings # NOQA
class ExperimentalWarning(UserWarning): pass # NOQA
warnings.warn(
"The combination of `async_comm` and `sequence_parallel_enabled` is not well tested.",
ExperimentalWarning,
)
# Create placeholder tensors for receive in forward and backward directions if needed.
tensor_recv_prev = None
tensor_recv_next = None
@@ -113,25 +180,45 @@ def _communicate(
# In megatron, `tensor_shape` is set to `(args.seq_length, args.micro_batch_size, args.hidden_size)`
raise RuntimeError(
"`tensor_shape` must be specified. Common `tensor_shape` is `(seq_length, micro_batch_size, hidden_size)`")
if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
tensor_chunk_shape = (reduce(operator.mul, tensor_shape, 1) // parallel_state.get_tensor_model_parallel_world_size(),)
tensor_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
override_scatter_gather_tensors_in_pipeline_ = False
# TODO(mkozuki): Demystify hardcode False of `scatter_gather_tensors_in_pipeline` and add a testcase if possible.
# NOTE(mkozuki): This is super strange and doesn't make sense to me. I have no idea what is happening here.
# However, I can say that this hardcoding override is necessary for sequence parallel in nemo megatron to work.
# I've not managed to reproduce the hang using standalone GPT with sequence parallel.
# The hang in NeMo Megatron happens in the 3rd iteration, the last iteration of the steady phase inside
# forward_backward_pipelining_without_interleaving, pipeline parallel rank of 0 (tensor model parallel world
# size of 2 and pipeline model parallel world size of 2). The commit then of APEX and NeMo were
# https://github.com/NVIDIA/apex/pull/1396/commits/3060c98dd8ba42abf7702ea9d2cff0f39ea74f45 and
# https://github.com/NVIDIA/NeMo/pull/4232/commits/1cb32dfca2ab9b20f53ebdb84476c34cb42f0205.
# The PyTorch version was 1.13.0a0+git2d354cd, for what is worth.
# Currently, this is indiscriminately set to `False`, which can lead to an unexpected performance regression
# for the non sequence parallel case.
scatter_gather_tensors_in_pipeline = False
if scatter_gather_tensors_in_pipeline and not sequence_parallel_enabled:
tensor_chunk_size = int(reduce(operator.mul, tensor_shape, 1))
if tensor_chunk_size % tensor_parallel_size == 0:
tensor_chunk_shape = [tensor_chunk_size // tensor_parallel_size]
else:
tensor_chunk_shape = tensor_shape
override_scatter_gather_tensors_in_pipeline_ = True
else:
tensor_chunk_shape = tensor_shape
# NOTE(mkozuki): In PyTorch AMP, i.e. `torch.cuda.amp.autocast` context, activation tensors can be either FP32,
# FP16, or BF16 and there's no way to tell the dtypes of tensors on different devices in general.
# It might be possible if we restrict model architecture.
# dtype = params_dtype or torch.float
# if fp32_residual_connection:
# dtype = torch.float
# if dtype_ is not None:
# dtype = dtype_
# requires_grad = False
if dtype_ != torch.float32 or params_dtype is not None:
if torch.distributed.get_rank() == 0:
warnings.warn("Tensor P2P communications are executed in FP32")
dtype = torch.float32
# The dtype logic below is copied from NVIDIA/Megatron-LM repo:
# https://github.com/NVIDIA/Megatron-LM/blob/d41696840ed0a7edb7e0499eb82a48ae112d9bb3/megatron/p2p_communication.py#L74-L81
dtype = params_dtype or torch.float
if fp32_residual_connection:
dtype = torch.float
requires_grad = True
if dtype_ is not None:
dtype = dtype_
# TODO(mkozuki): Figure out why this logic of requires_grad isn't working
# when sequence_parallel_enabled=True. Otherwise, `x.retain_grad()` of
# https://github.com/crcrpar/apex/blob/069832078a652b4bd8a99db84faf953a81415ab3/apex/transformer/pipeline_parallel/schedules/common.py#L360
# fails.
# requires_grad = False
if recv_prev:
tensor_recv_prev = torch.empty(
@@ -149,7 +236,12 @@ def _communicate(
)
# Split tensor into smaller chunks if using scatter-gather optimization.
if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
scatter_gather_optimization_doable = (
not override_scatter_gather_tensors_in_pipeline_
and scatter_gather_tensors_in_pipeline
and not sequence_parallel_enabled
)
if scatter_gather_optimization_doable:
if tensor_send_next is not None:
tensor_send_next = split_tensor_into_1d_equal_chunks(tensor_send_next)
@@ -157,41 +249,89 @@ def _communicate(
tensor_send_prev = split_tensor_into_1d_equal_chunks(tensor_send_prev)
# Send tensors in both the forward and backward directions as appropriate.
_run_p2pops(tensor_send_prev, tensor_send_next, tensor_recv_prev, tensor_recv_next)
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()
tensor_send_prev_req, tensor_recv_prev_req, tensor_send_next_req, tensor_recv_next_req = _run_p2pops(tensor_send_prev, tensor_send_next, tensor_recv_prev, tensor_recv_next, async_comm=async_comm)
if async_comm:
tensor_recv_prev_waitfunc = None
tensor_recv_next_waitfunc = None
# TODO: investigate whether this is necessary for correctness (ref: https://github.com/pytorch/pytorch/issues/38642)
# see also: sync added for async_comm callbacks below in gather_recv_prev_wait and gather_recv_next_wait
if tensor_recv_prev_req is not None:
def tensor_recv_prev_wait():
tensor_recv_prev_req.wait()
torch.cuda.synchronize()
tensor_recv_prev_waitfunc = tensor_recv_prev_wait
if tensor_recv_next_req is not None:
def tensor_recv_next_wait():
tensor_recv_next_req.wait()
torch.cuda.synchronize()
tensor_recv_next_waitfunc = tensor_recv_next_wait
else:
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()
# If using scatter-gather optimization, gather smaller chunks.
if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
if recv_prev:
tensor_recv_prev = (
gather_split_1d_tensor(tensor_recv_prev)
.view(tensor_shape)
.requires_grad_()
)
if recv_next:
tensor_recv_next = (
gather_split_1d_tensor(tensor_recv_next)
.view(tensor_shape)
.requires_grad_()
)
if scatter_gather_optimization_doable:
if not async_comm:
if recv_prev:
tensor_recv_prev = (
gather_split_1d_tensor(tensor_recv_prev)
.view(tensor_shape)
.requires_grad_()
)
if recv_next:
tensor_recv_next = (
gather_split_1d_tensor(tensor_recv_next)
.view(tensor_shape)
.requires_grad_()
)
else:
def gather_recv_prev_wait():
tensor_recv_prev_req.wait()
# From @Deepak's PR https://github.com/NVIDIA/Megatron-LM/commit/27fc468964064eeb33b703c9a0b2af938d80dd14
# A sync seems to be needed before gather otherwise losses jump around e.g., in run_gpt_minimal_test
torch.cuda.synchronize()
return (
gather_split_1d_tensor(tensor_recv_prev)
.view(tensor_shape)
.requires_grad_()
)
def gather_recv_next_wait():
tensor_recv_next_req.wait()
torch.cuda.synchronize()
return (
gather_split_1d_tensor(tensor_recv_next)
.view(tensor_shape)
.requires_grad_()
)
tensor_recv_prev_waitfunc = gather_recv_prev_wait
tensor_recv_next_waitfunc = gather_recv_next_wait
if async_comm:
future_tensor_recv_prev = None
future_tensor_recv_next = None
if tensor_recv_prev is not None:
future_tensor_recv_prev = FutureTensor(tensor_recv_prev, tensor_recv_prev_waitfunc)
if tensor_recv_next is not None:
future_tensor_recv_next = FutureTensor(tensor_recv_next, tensor_recv_next_waitfunc)
return future_tensor_recv_prev, future_tensor_recv_next
return tensor_recv_prev, tensor_recv_next
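When the scatter-gather optimization applies (and sequence parallel is off), each rank P2P-sends only a 1/tensor-parallel-size flat chunk and the receiver reassembles the activation with `gather_split_1d_tensor(...).view(tensor_shape).requires_grad_()`. A plain-torch sketch of that round trip, simulated locally with illustrative shapes:
import torch

tensor_parallel_size, tensor_shape = 4, (128, 2, 1024)
full = torch.randn(tensor_shape)
flat = full.view(-1)
chunk = flat.chunk(tensor_parallel_size)[0]  # what a single rank would actually send
gathered = torch.cat(flat.chunk(tensor_parallel_size)).view(tensor_shape).requires_grad_()
assert torch.equal(gathered, full)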
def recv_forward(
tensor_shape: Shape,
override_scatter_gather_tensors_in_pipeline: bool = False,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
) -> torch.Tensor:
tensor_shape: Shape,
override_scatter_gather_tensors_in_pipeline: bool = False,
*,
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
timers: _Timers = None,
) -> Union[torch.Tensor, FutureTensor, None]:
"""Receive tensor from previous rank in pipeline (forward receive)."""
if parallel_state.is_pipeline_first_stage():
return None
if timers is not None:
timers("forward-recv").start()
# if timers is not None:
# timers("forward-recv").start()
input_tensor, _ = _communicate(
tensor_send_next=None,
tensor_send_prev=None,
@@ -199,50 +339,58 @@ def recv_forward(
recv_next=False,
tensor_shape=tensor_shape,
override_scatter_gather_tensors_in_pipeline=override_scatter_gather_tensors_in_pipeline,
dtype_=_get_current_dtype(dtype),
dtype_=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
if timers is not None:
timers("forward-recv").stop()
# if timers is not None:
# timers("forward-recv").stop()
return input_tensor
def recv_backward(
tensor_shape: Shape = None,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
):
tensor_shape: Shape = None,
*,
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
timers: _Timers = None,
) -> Union[torch.Tensor, FutureTensor, None]:
"""Receive tensor from next rank in pipeline (backward receive)."""
if parallel_state.is_pipeline_last_stage():
return None
if timers is not None:
timers("backward-recv").start()
# if timers is not None:
# timers("backward-recv").start()
_, output_tensor_grad = _communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_prev=False,
recv_next=True,
tensor_shape=tensor_shape,
dtype_=_get_current_dtype(dtype),
dtype_=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
if timers is not None:
timers("backward-recv").stop()
# if timers is not None:
# timers("backward-recv").stop()
return output_tensor_grad
def send_forward(
output_tensor: torch.Tensor,
override_scatter_gather_tensors_in_pipeline: bool = False,
tensor_shape: Shape = None,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
output_tensor: torch.Tensor,
override_scatter_gather_tensors_in_pipeline: bool = False,
tensor_shape: Shape = None,
*,
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
timers: _Timers = None,
) -> None:
"""Send tensor to next rank in pipeline (forward send)."""
if parallel_state.is_pipeline_last_stage():
return
if timers is not None:
timers("forward-send").start()
# if timers is not None:
# timers("forward-send").start()
_communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
@@ -250,155 +398,181 @@ def send_forward(
recv_next=False,
override_scatter_gather_tensors_in_pipeline=override_scatter_gather_tensors_in_pipeline,
tensor_shape=tensor_shape,
dtype_=_get_current_dtype(dtype),
dtype_=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
if timers is not None:
timers("forward-send").stop()
# if timers is not None:
# timers("forward-send").stop()
def send_backward(
input_tensor_grad: torch.Tensor,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
input_tensor_grad: torch.Tensor,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
timers: _Timers = None,
) -> None:
"""Send tensor to previous rank in pipeline (backward send)."""
if parallel_state.is_pipeline_first_stage():
return
if timers is not None:
timers("backward-send").start()
# if timers is not None:
# timers("backward-send").start()
_communicate(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=False,
recv_next=False,
tensor_shape=tensor_shape,
dtype_=_get_current_dtype(dtype),
dtype_=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
if timers is not None:
timers("backward-send").stop()
# if timers is not None:
# timers("backward-send").stop()
def send_forward_recv_backward(
output_tensor: torch.Tensor,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
) -> None:
output_tensor: torch.Tensor,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
timers: _Timers = None,
) -> Union[torch.Tensor, FutureTensor, None]:
"""Batched send and recv with next rank in pipeline."""
if parallel_state.is_pipeline_last_stage():
return None
if timers is not None:
timers("forward-send-backward-recv").start()
# if timers is not None:
# timers("forward-send-backward-recv").start()
_, output_tensor_grad = _communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=False,
recv_next=True,
tensor_shape=tensor_shape,
dtype_=_get_current_dtype(dtype),
dtype_=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
if timers is not None:
timers("forward-send-backward-recv").stop()
# if timers is not None:
# timers("forward-send-backward-recv").stop()
return output_tensor_grad
def send_backward_recv_forward(
input_tensor_grad: torch.Tensor,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
) -> torch.Tensor:
input_tensor_grad: torch.Tensor,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
timers: _Timers = None,
) -> Union[torch.Tensor, FutureTensor, None]:
"""Batched send and recv with previous rank in pipeline."""
if parallel_state.is_pipeline_first_stage():
return None
if timers is not None:
timers("backward-send-forward-recv").start()
# if timers is not None:
# timers("backward-send-forward-recv").start()
input_tensor, _ = _communicate(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=True,
recv_next=False,
tensor_shape=tensor_shape,
dtype_=_get_current_dtype(dtype),
dtype_=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
if timers is not None:
timers("backward-send-forward-recv").stop()
# if timers is not None:
# timers("backward-send-forward-recv").stop()
return input_tensor
def send_forward_recv_forward(
output_tensor: torch.Tensor,
recv_prev: bool,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
) -> torch.Tensor:
output_tensor: torch.Tensor,
recv_prev: bool,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
timers: _Timers = None,
) -> Union[torch.Tensor, FutureTensor]:
"""Batched recv from previous rank and send to next rank in pipeline."""
if timers is not None:
timers("forward-send-forward-recv").start()
# if timers is not None:
# timers("forward-send-forward-recv").start()
input_tensor, _ = _communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=recv_prev,
recv_next=False,
tensor_shape=tensor_shape,
dtype_=_get_current_dtype(dtype),
dtype_=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
if timers is not None:
timers("forward-send-forward-recv").stop()
# if timers is not None:
# timers("forward-send-forward-recv").stop()
return input_tensor
def send_backward_recv_backward(
input_tensor_grad: torch.Tensor,
recv_next: bool,
tensor_shape: Shape,
*,
dtype: torch.dtype = torch.float,
timers: _Timers = None,
) -> torch.Tensor:
input_tensor_grad: torch.Tensor,
recv_next: bool,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
timers: _Timers = None,
) -> Union[torch.Tensor, FutureTensor]:
"""Batched recv from next rank and send to previous rank in pipeline."""
if timers is not None:
timers("backward-send-backward-recv").start()
# if timers is not None:
# timers("backward-send-backward-recv").start()
_, output_tensor_grad = _communicate(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=False,
recv_next=recv_next,
tensor_shape=tensor_shape,
dtype_=_get_current_dtype(dtype),
dtype_=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
if timers is not None:
timers("backward-send-backward-recv").stop()
# if timers is not None:
# timers("backward-send-backward-recv").stop()
return output_tensor_grad
def send_forward_backward_recv_forward_backward(
output_tensor: torch.Tensor,
input_tensor_grad: torch.Tensor,
recv_prev: bool,
recv_next: bool,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
):
output_tensor: torch.Tensor,
input_tensor_grad: torch.Tensor,
recv_prev: bool,
recv_next: bool,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
timers: _Timers = None,
) -> Tuple[Union[torch.Tensor, FutureTensor], Union[torch.Tensor, FutureTensor]]:
"""Batched send and recv with previous and next ranks in pipeline."""
if timers is not None:
timers("forward-backward-send-forward-backward-recv").start()
# if timers is not None:
# timers("forward-backward-send-forward-backward-recv").start()
input_tensor, output_tensor_grad = _communicate(
tensor_send_next=output_tensor,
tensor_send_prev=input_tensor_grad,
recv_prev=recv_prev,
recv_next=recv_next,
tensor_shape=tensor_shape,
dtype_=_get_current_dtype(dtype),
dtype_=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
if timers is not None:
timers("forward-backward-send-forward-backward-recv").stop()
# if timers is not None:
# timers("forward-backward-send-forward-backward-recv").stop()
return input_tensor, output_tensor_grad
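A sketch of how a pipeline schedule consumes these helpers with `async_comm=True`; the shapes, dtype, and the stand-in forward pass are illustrative:
import torch
from apex.transformer.pipeline_parallel import p2p_communication as p2p

tensor_shape = (2048, 1, 4096)  # (seq_length, micro_batch_size, hidden_size)
fut = p2p.recv_forward(tensor_shape=tensor_shape, dtype=torch.float16, async_comm=True)
# ... overlap compute for another microbatch with the outstanding irecv ...
input_tensor = fut.get() if isinstance(fut, p2p.FutureTensor) else fut
output_tensor = input_tensor  # stand-in for this stage's forward pass
p2p.send_forward(output_tensor, tensor_shape=tensor_shape, dtype=torch.float16, async_comm=True)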
import warnings
from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import forward_backward_no_pipelining
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import _forward_backward_pipelining_with_interleaving
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import (
forward_backward_no_pipelining,
)
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import (
_forward_backward_pipelining_with_interleaving,
)
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import (
forward_backward_pipelining_without_interleaving,
)
__all__ = [
"get_forward_backward_func",
]
class ExperimentalWarning(Warning):
pass
@@ -21,19 +27,9 @@ def get_forward_backward_func(
if get_num_microbatches() % pipeline_model_parallel_size != 0:
msg = "number of microbatches is not divisible by pipeline-parallel size when using interleaved schedule"
raise RuntimeError(msg)
warnings.warn(
"Pipeline Model Parallel with interleaving scheduling is experimental. "
f"To use Pipeline Parallel without interleaving, set `virtual_pipeline_model_parallel_size` to `None`: {virtual_pipeline_model_parallel_size}",
ExperimentalWarning
)
forward_backward_func = _forward_backward_pipelining_with_interleaving
else:
forward_backward_func = forward_backward_pipelining_without_interleaving
else:
forward_backward_func = forward_backward_no_pipelining
return forward_backward_func
__all__ = [
"get_forward_backward_func",
]
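The schedule is selected purely from the two parallel sizes; a brief sketch (keyword names inferred from the function body above):
from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel import get_forward_backward_func

forward_backward_func = get_forward_backward_func(
    virtual_pipeline_model_parallel_size=None,  # None selects the non-interleaved path
    pipeline_model_parallel_size=parallel_state.get_pipeline_model_parallel_world_size(),
)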
# NOTE (mkozuki): For simplicity, tentatively `timers` related operations are commented out.
from typing import Any, Callable, Dict, List, Tuple, Union, Optional
from typing import Any, Callable, Dict, List, Tuple, Union, Optional, Sequence
import torch
from torch.autograd.variable import Variable
from apex.normalization.fused_layer_norm import FusedLayerNorm
from apex.transformer import parallel_state
from apex.transformer.enums import ModelType
from apex.transformer.pipeline_parallel.p2p_communication import FutureTensor
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.utils import listify_model
from apex.transformer.pipeline_parallel.utils import unwrap_model
from apex.transformer.tensor_parallel.layers import set_defaults_if_not_set_tensor_model_parallel_attributes
from apex.transformer.pipeline_parallel.utils import get_model_type
from apex.transformer.tensor_parallel.layers import (
set_defaults_if_not_set_tensor_model_parallel_attributes,
)
from apex.transformer.log_util import get_transformer_logger
Batch = Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]
_logger = get_transformer_logger(__name__)
Batch = Union[torch.Tensor, FutureTensor, List[Union[torch.Tensor, FutureTensor]], Tuple[Union[torch.Tensor, FutureTensor], ...]]
LossFunc = Callable[[torch.Tensor], torch.Tensor]
FwdStepFunc = Callable[[Batch, torch.nn.Module], Tuple[torch.Tensor, LossFunc]]
FwdStepFunc = Callable[
[Optional[Batch], torch.nn.Module], Tuple[torch.Tensor, LossFunc]
]
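# Illustrative sketch (hypothetical, not part of apex): a minimal function
# matching the `FwdStepFunc` contract used by the schedules below. The batch
# layout (a single input tensor, or a list whose first element is the input)
# and the reduction inside `loss_func` are assumptions for demonstration only.
def _example_forward_step(batch: Optional[Batch], model: torch.nn.Module):
    inputs = batch[0] if isinstance(batch, (list, tuple)) else batch
    output_tensor = model(inputs)

    def loss_func(output_tensor: torch.Tensor):
        loss = output_tensor.float().mean()
        # The second element is what the last pipeline stage appends to `losses_reduced`.
        return loss, {"loss": loss.detach()}

    return output_tensor, loss_func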
def build_model(
model_provider_func: Callable[[Any, Dict[str, Any]], torch.nn.Module],
wrap_with_ddp: bool = True,
virtual_pipeline_model_parallel_size: Optional[int] = None,
*args,
**kwargs
model_provider_func: Callable[[Any, Dict[str, Any]], torch.nn.Module],
wrap_with_ddp: bool = True,
virtual_pipeline_model_parallel_size: Optional[int] = None,
model_type: ModelType = ModelType.encoder_or_decoder,
*args: Any,
**kwargs: Any,
) -> List[torch.nn.Module]:
"""Build the model satisfying pipeline model parallel requirements.
......@@ -32,6 +45,7 @@ def build_model(
wrap_with_ddp: If :obj:`True`, wrap the instantiated model
with `torch.nn.parallel.distributed.DistributedDataParallel`, a.k.a. `DDP`.
virtual_pipeline_model_parallel_size: Specify when using interleaving scheduling pipeline model parallel.
model_type:
*args: arguments for model provider func
**kwargs: Keyword arguments for model provider func
......@@ -40,8 +54,8 @@ def build_model(
the list has multiple models, otherwise one.
"""
if (
parallel_state.get_pipeline_model_parallel_world_size() > 1 and
virtual_pipeline_model_parallel_size is not None
parallel_state.get_pipeline_model_parallel_world_size() > 1
and virtual_pipeline_model_parallel_size is not None
):
model = []
for i in range(virtual_pipeline_model_parallel_size):
......@@ -51,22 +65,48 @@ def build_model(
# Set pre_process and post_process only after virtual rank is set.
pre_process = parallel_state.is_pipeline_first_stage()
post_process = parallel_state.is_pipeline_last_stage()
cur_kwargs.update({
"pre_process": pre_process,
"post_process": post_process,
})
cur_kwargs.update(
{"pre_process": pre_process, "post_process": post_process,}
)
this_model = model_provider_func(*cur_args, **cur_kwargs)
model.append(this_model)
else:
cur_args = args
cur_kwargs = kwargs
pre_process = parallel_state.is_pipeline_first_stage()
post_process = parallel_state.is_pipeline_last_stage()
cur_kwargs.update({
"pre_process": pre_process,
"post_process": post_process,
})
model = model_provider_func(*cur_args, **cur_kwargs)
if model_type == ModelType.encoder_or_decoder:
pre_process = parallel_state.is_pipeline_first_stage()
post_process = parallel_state.is_pipeline_last_stage()
cur_kwargs.update(
{"pre_process": pre_process, "post_process": post_process,}
)
model = model_provider_func(*cur_args, **cur_kwargs)
elif model_type == ModelType.encoder_and_decoder:
pre_process = parallel_state.is_pipeline_first_stage()
post_process = parallel_state.is_pipeline_last_stage()
# `add_encoder` & `add_decoder` logic.
add_encoder, add_decoder = True, True
if parallel_state.get_pipeline_model_parallel_world_size() > 1:
split_rank = parallel_state.get_pipeline_model_parallel_split_rank()
if split_rank is None:
raise RuntimeError(
"Split rank needs to be specified for model with both encoder and decoder."
)
rank = parallel_state.get_pipeline_model_parallel_rank()
world_size = parallel_state.get_pipeline_model_parallel_world_size()
pre_process = rank == 0 or rank == split_rank
post_process = rank == (split_rank - 1) or rank == (world_size - 1)
add_encoder = parallel_state.is_pipeline_stage_before_split()
add_decoder = parallel_state.is_pipeline_stage_after_split()
cur_kwargs.update(
{
"pre_process": pre_process,
"post_process": post_process,
"add_encoder": add_encoder,
"add_decoder": add_decoder,
}
)
model = model_provider_func(*cur_args, **cur_kwargs)
model.model_type = model_type
if not isinstance(model, list):
model = [model]
......@@ -80,11 +120,14 @@ def build_model(
set_defaults_if_not_set_tensor_model_parallel_attributes(param)
# Print number of parameters.
if parallel_state.get_data_parallel_rank() == 0:
if (
parallel_state.model_parallel_is_initialized()
and parallel_state.get_data_parallel_rank() == 0
):
msg = " > number of parameters on (tensor, pipeline) model parallel rank ({}, {}): {}".format(
parallel_state.get_tensor_model_parallel_rank(),
parallel_state.get_pipeline_model_parallel_rank(),
sum([sum([p.nelement() for p in model_module.parameters()]) for model_module in model])
_calc_number_of_params(model),
)
print(msg, flush=True)
......@@ -106,44 +149,119 @@ def build_model(
return model
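# Illustrative usage sketch (hypothetical, not part of apex): driving
# `build_model` with a toy model provider. The provider below is a placeholder;
# a real provider must accept `pre_process`/`post_process` (and, for
# ModelType.encoder_and_decoder, `add_encoder`/`add_decoder`). The sketch
# assumes `apex.transformer.parallel_state` has already been initialized.
def _example_build_model() -> List[torch.nn.Module]:
    def toy_model_provider(pre_process: bool = True, post_process: bool = True) -> torch.nn.Module:
        # Placeholder module; a real provider would construct the transformer here.
        return torch.nn.Linear(16, 16)

    return build_model(
        toy_model_provider,
        wrap_with_ddp=False,
        virtual_pipeline_model_parallel_size=None,
        model_type=ModelType.encoder_or_decoder,
    )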
def _calc_number_of_params(model: List[torch.nn.Module]) -> int:
assert isinstance(model, list)
return sum(
[
sum([p.nelement() for p in model_module.parameters()])
for model_module in model
]
)
def _get_params_for_weight_decay_optimization(
model: Union[torch.nn.Module, List[torch.nn.Module]],
model: Union[torch.nn.Module, List[torch.nn.Module]],
*,
no_weight_decay_modules=(FusedLayerNorm,),
) -> Dict[str, torch.nn.Parameter]:
"""Divide params into with-weight-decay and without-weight-decay groups.
Layernorms and biases will have no weight decay but the rest will.
"""
modules = listify_model(model)
from apex.normalization.fused_layer_norm import FusedLayerNorm # NOQA
weight_decay_params = {'params': []}
no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
weight_decay_params = {"params": []}
no_weight_decay_params = {"params": [], "weight_decay": 0.0}
for module in modules:
for module_ in module.modules():
if isinstance(module_, FusedLayerNorm):
no_weight_decay_params['params'].extend(
[p for p in list(module_._parameters.values())
if p is not None])
if isinstance(module_, no_weight_decay_modules):
no_weight_decay_params["params"].extend(
[p for p in list(module_._parameters.values()) if p is not None]
)
else:
weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and n != 'bias'])
no_weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and n == 'bias'])
weight_decay_params["params"].extend(
[
p
for n, p in list(module_._parameters.items())
if p is not None and n != "bias"
]
)
no_weight_decay_params["params"].extend(
[
p
for n, p in list(module_._parameters.items())
if p is not None and n == "bias"
]
)
return weight_decay_params, no_weight_decay_params
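# Illustrative sketch (hypothetical, not part of apex): handing the two groups
# returned above to an optimizer. The learning rate and weight decay values are
# placeholders.
def _example_weight_decay_groups(model: torch.nn.Module) -> torch.optim.Optimizer:
    weight_decay_params, no_weight_decay_params = _get_params_for_weight_decay_optimization(model)
    # `no_weight_decay_params` already carries `weight_decay: 0.0`, so only the
    # decayed group needs an explicit value.
    return torch.optim.AdamW(
        [weight_decay_params, no_weight_decay_params], lr=1e-4, weight_decay=0.01
    )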
def free_output_tensor(
output_tensors: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]],
deallocate_pipeline_outputs: bool = False,
) -> None:
"""Pseudo-free the output tensor's `.data` field.
This method should be called right after the output tensor has been sent to the next
pipeline stage. At this point, the output tensor is only useful for its `.grad_fn` field,
and not its `.data`.
"""
if not deallocate_pipeline_outputs:
return
if output_tensors is None:
return
if isinstance(output_tensors, torch.Tensor):
output_tensors = [output_tensors]
for output_tensor in output_tensors:
output_tensor.data = torch.cuda.FloatTensor([0])
def custom_backward(output: torch.Tensor, grad_output: Optional[torch.Tensor]) -> None:
"""Directly call C++ autograd engine.
To make the `free_output_tensor` optimization work, the C++ autograd engine must be called
directly, bypassing PyTorch's `torch.autograd.backward`. PyTorch's `backward` checks that the
output and grad have the same shape, while C++ `backward` does not.
"""
assert (
output.numel() == 1
), "output should be pseudo-freed in schedule, to optimize memory consumption"
assert isinstance(output, torch.Tensor), "output == {}.".format(
type(output).__name__
)
assert isinstance(
grad_output, (torch.Tensor, type(None))
), "grad_outptu == {}.".format(type(grad_output).__name__)
# Handle scalar output
if grad_output is None:
assert output.numel() == 1, "Implicit grad requires scalar output."
grad_output = torch.ones_like(output, memory_format=torch.preserve_format)
# Call C++ engine [ see torch/csrc/autograd/python_engine.cpp ]
Variable._execution_engine.run_backward(
tensors=(output,),
grad_tensors=(grad_output,),
keep_graph=False,
create_graph=False,
inputs=(),
allow_unreachable=True,
accumulate_grad=True,
)
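# Illustrative sketch (hypothetical, not part of apex): how `free_output_tensor`
# and `custom_backward` pair up inside a schedule. Requires a CUDA device; the
# tensors stand in for a stage's output and the gradient received from the next
# stage.
def _example_deallocate_then_backward() -> None:
    output_tensor = torch.randn(4, 2, 8, device="cuda", requires_grad=True) * 2.0
    output_tensor_grad = torch.randn_like(output_tensor)
    # Once the output has been sent downstream, only its `.grad_fn` is needed,
    # so its `.data` is replaced by a dummy scalar.
    free_output_tensor(output_tensor, deallocate_pipeline_outputs=True)
    # `torch.autograd.backward` would now reject the shape mismatch, but the
    # C++ engine invoked by `custom_backward` does not check it.
    custom_backward(output_tensor, output_tensor_grad)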
def forward_step(
forward_step_func: FwdStepFunc,
batch: Batch,
model: torch.nn.Module,
input_tensor: Optional[torch.Tensor],
losses_reduced: List[torch.Tensor],
):
forward_step_func: FwdStepFunc,
batch: Optional[Batch],
model: torch.nn.Module,
input_tensor: Optional[Union[torch.Tensor, List[torch.Tensor]]],
losses_reduced: List[torch.Tensor],
dtype: torch.dtype,
disable_autocast: bool = False,
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
"""Forward step for passed-in model.
If first stage, input tensor is obtained from data_iterator, otherwise
passed-in input_tensor is used.
If first stage, input tensor is obtained from batch, otherwise passed-in input_tensor is used.
Returns output tensor.
......@@ -154,6 +272,8 @@ def forward_step(
model: unwrappable model
input_tensor:
losses_reduced:
dtype:
disable_autocast:
Returns:
output_tensor
......@@ -161,27 +281,51 @@ def forward_step(
# timers = get_timers()
# timers("forward-compute").start()
unwrapped_model = unwrap_model(model)
model_type = get_model_type(unwrapped_model)
# NOTE (mkozuki): The passed `model` is expected to implement `set_input_tensor`.
# See https://github.com/NVIDIA/Megatron-LM/blob/5ac5571ba0265af4c491ee0af1508ca7589450c6/megatron/model/transformer.py#L679 # NOQA
# for the details of `set_input_tensor`.
unwrap_output_tensor = not isinstance(input_tensor, list)
if unwrap_output_tensor:
input_tensor = [input_tensor]
input_tensor = [inp.get() if isinstance(inp, FutureTensor) else inp for inp in input_tensor]
unwrapped_model.set_input_tensor(input_tensor)
output_tensor, loss_func = forward_step_func(batch, model)
# print(f"forward_step| pipeline rank: {parallel_state.get_pipeline_model_parallel_rank()} is_pipeline_last_stage?: {parallel_state.is_pipeline_last_stage()}")
if parallel_state.is_pipeline_last_stage():
output_tensor = loss_func(output_tensor)
loss, loss_reduced = output_tensor
output_tensor = loss / get_num_microbatches()
losses_reduced.append(loss_reduced)
with torch.cuda.amp.autocast(
enabled=not disable_autocast and dtype in (torch.half, torch.bfloat16),
dtype=dtype,
):
output_tensor, loss_func = forward_step_func(batch, model)
if parallel_state.is_pipeline_last_stage():
output_tensor = loss_func(output_tensor)
loss, loss_reduced = output_tensor
output_tensor = loss / get_num_microbatches()
losses_reduced.append(loss_reduced)
# timers("forward-compute").stop()
return output_tensor
# If T5 model (or other model with encoder and decoder)
# and in decoder stack, then send encoder_hidden_state
# downstream as well.
if (
parallel_state.is_pipeline_stage_after_split()
and model_type == ModelType.encoder_and_decoder
):
return [output_tensor, input_tensor[-1]]
if unwrap_output_tensor:
return output_tensor
return [output_tensor]
def backward_step(
input_tensor: Optional[torch.Tensor],
output_tensor: torch.Tensor,
output_tensor_grad: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
input_tensor: Optional[torch.Tensor],
output_tensor: torch.Tensor,
output_tensor_grad: Optional[torch.Tensor],
model_type: ModelType,
*,
grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
deallocate_pipeline_outputs: bool = False,
) -> Union[None, torch.Tensor, Sequence[torch.Tensor]]:
"""Backward step through passed-in output tensor.
If last stage, output_tensor_grad is None, otherwise gradient of loss
......@@ -194,25 +338,61 @@ def backward_step(
input_tensor:
output_tensor:
output_tensor_grad:
Keyword Arguments:
grad_scaler:
deallocate_pipeline_outputs: Experimental.
Returns:
input_tensor_grad
"""
# timers = get_timers()
# timers("backward-compute").start()
# Retain the grad on the input_tensor.
unwrap_input_tensor_grad = not isinstance(input_tensor, list)
if unwrap_input_tensor_grad:
input_tensor = [input_tensor]
input_tensor = [inp.get() if isinstance(inp, FutureTensor) else inp for inp in input_tensor]
for x in input_tensor:
if x is not None:
x.retain_grad()
if not isinstance(output_tensor, list):
output_tensor = [output_tensor]
output_tensor = [out.get() if isinstance(out, FutureTensor) else out for out in output_tensor]
if not isinstance(output_tensor_grad, list):
output_tensor_grad = [output_tensor_grad]
output_tensor_grad = [ogr.get() if isinstance(ogr, FutureTensor) else ogr for ogr in output_tensor_grad]
# if parallel_state.get_pipeline_model_parallel_rank() == 0:
# print(f"{input_tensor}, {output_tensor}, {output_tensor_grad}")
if input_tensor is not None:
input_tensor.retain_grad()
# Backward pass.
# if output_tensor_grad is None:
# output_tensor = optimizer.scale_loss(output_tensor)
torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
input_tensor_grad = None
if grad_scaler is not None and output_tensor_grad[0] is None:
output_tensor[0] = grad_scaler.scale(output_tensor[0])
if deallocate_pipeline_outputs:
custom_backward(output_tensor[0], output_tensor_grad[0])
else:
torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0])
# Collect the grad of the input_tensor.
input_tensor_grad = [None]
if input_tensor is not None:
input_tensor_grad = input_tensor.grad
# timers("backward-compute").stop()
input_tensor_grad = []
for x in input_tensor:
input_tensor_grad.append(None if x is None else x.grad)
return input_tensor_grad
# Handle single skip connection if it exists (encoder_hidden_state in model with encoder and decoder).
if (
parallel_state.get_pipeline_model_parallel_world_size() > 1
and parallel_state.is_pipeline_stage_after_split()
and model_type == ModelType.encoder_and_decoder
):
if output_tensor_grad[1] is not None:
# todo (mkozuki): Replace the inplace add with `+= output_tensor_grad[1]`?
input_tensor_grad[-1].add_(output_tensor_grad[1])
# timers("backward-compute").stop()
return input_tensor_grad[0] if unwrap_input_tensor_grad else input_tensor_grad
from contextlib import contextmanager
from typing import List, Union
from typing import List, Union, Optional
import torch
from apex.transformer.pipeline_parallel.utils import listify_model
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.utils import get_kth_microbatch
from apex.transformer.pipeline_parallel.schedules.common import Batch, FwdStepFunc
from apex.transformer.pipeline_parallel.utils import get_model_type
from apex.transformer.pipeline_parallel.schedules.common import Batch
from apex.transformer.pipeline_parallel.schedules.common import FwdStepFunc
from apex.transformer.pipeline_parallel.schedules.common import forward_step
from apex.transformer.pipeline_parallel.schedules.common import backward_step
from apex.transformer.log_util import get_transformer_logger
......@@ -27,12 +29,16 @@ def placeholder_handler():
def forward_backward_no_pipelining(
forward_step_func: FwdStepFunc,
batch: Batch,
model: Union[torch.nn.Module, List[torch.nn.Module]],
*,
forward_only: bool,
**kwargs,
forward_step_func: FwdStepFunc,
batch: Batch,
model: Union[torch.nn.Module, List[torch.nn.Module]],
*,
forward_only: bool,
dtype: Optional[torch.dtype] = None,
grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
disable_autocast: bool = False,
custom_sync_context_handler=None,
**kwargs,
):
"""Run forward and backward passes with no pipeline parallelism (no inter-stage communication).
......@@ -48,6 +54,12 @@ def forward_backward_no_pipelining(
Keyword args:
forward_only:
grad_scaler:
dtype:
disable_autocast: Turn off `enabled` flag of `torch.cuda.amp.autocast` if :obj:`True`.
Should be used when your forward and loss computation is in the autocast context to
avoid unnecessarily nesting autocast contexts.
custom_sync_context_handler:
**kwargs: Added to handle `tensor_shape` which has no effect on this function.
Returns:
......@@ -58,10 +70,14 @@ def forward_backward_no_pipelining(
msg = f"`model` is expected be a `nn.Module`, but {type(model)}"
raise RuntimeError(msg)
model = model[0]
model_type = get_model_type(model)
context_handler = placeholder_handler
if isinstance(model, torch.nn.parallel.distributed.DistributedDataParallel):
if custom_sync_context_handler is not None:
context_handler = custom_sync_context_handler
elif isinstance(model, torch.nn.parallel.distributed.DistributedDataParallel):
context_handler = model.no_sync
else:
context_handler = placeholder_handler
losses_reduced = []
input_tensor, output_tensor_grad = None, None
......@@ -72,20 +88,45 @@ def forward_backward_no_pipelining(
cur_micro_batch = get_kth_microbatch(batch, i)
_logger.debug("Call `forward_step`")
output_tensor = forward_step(
forward_step_func, cur_micro_batch, model, input_tensor, losses_reduced)
forward_step_func,
cur_micro_batch,
model,
input_tensor,
losses_reduced,
dtype=dtype,
disable_autocast=disable_autocast,
)
if not forward_only:
_logger.debug("Call `backward_step`")
backward_step(input_tensor, output_tensor, output_tensor_grad)
backward_step(
input_tensor,
output_tensor,
output_tensor_grad,
model_type=model_type,
grad_scaler=grad_scaler,
)
# Run computation for last microbatch out of context handler (want to
# synchronize gradients).
_logger.info("Cooldown")
_logger.debug("Call `forward_step`")
output_tensor = forward_step(
forward_step_func, get_kth_microbatch(batch, num_micro_batches - 1), model, input_tensor, losses_reduced
forward_step_func,
get_kth_microbatch(batch, num_micro_batches - 1),
model,
input_tensor,
losses_reduced,
dtype=dtype,
disable_autocast=disable_autocast,
)
if not forward_only:
_logger.debug("Call `backward_step`")
backward_step(input_tensor, output_tensor, output_tensor_grad)
backward_step(
input_tensor,
output_tensor,
output_tensor_grad,
model_type=model_type,
grad_scaler=grad_scaler,
)
return losses_reduced
from typing import List, Union, Optional
from typing import List, Union, Optional, Sequence
import warnings
import torch
from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel import p2p_communication
from apex.transformer.pipeline_parallel.schedules.common import Batch, FwdStepFunc
from apex.transformer.pipeline_parallel.schedules.common import Batch
from apex.transformer.pipeline_parallel.schedules.common import FwdStepFunc
from apex.transformer.pipeline_parallel.schedules.common import backward_step
from apex.transformer.pipeline_parallel.schedules.common import forward_step
from apex.transformer.pipeline_parallel.schedules.common import free_output_tensor
from apex.transformer.pipeline_parallel.utils import get_kth_microbatch
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.utils import get_model_type
from apex.transformer.log_util import get_transformer_logger
......@@ -18,15 +22,22 @@ __all__ = ["_forward_backward_pipelining_with_interleaving"]
_logger = get_transformer_logger(__name__)
# TODO (mkozuki): Reduce cyclomatic complexity
# TODO(mkozuki): Reduce cyclomatic complexity
def _forward_backward_pipelining_with_interleaving(
forward_step_func: FwdStepFunc,
batch: List[Batch],
model: List[torch.nn.Module],
*,
forward_only: bool,
tensor_shape: Optional[Union[List[int], torch.Size]] = None,
):
forward_step_func: FwdStepFunc,
batch: List[Optional[Batch]],
model: List[torch.nn.Module],
*,
forward_only: bool,
tensor_shape: Optional[Union[List[int], torch.Size]] = None,
dtype: Optional[torch.dtype] = None,
grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
disable_autocast: bool = False,
deallocate_pipeline_outputs: bool = False,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
**kwargs,
) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
"""Run interleaved 1F1B schedule with communication between pipeline stages as needed.
This function assumes `batch` and `model` is a list of `Batch`'s and a list of `torch.nn.Module`, respectively.
......@@ -48,7 +59,17 @@ def _forward_backward_pipelining_with_interleaving(
Keyword args:
forward_only:
tensor_shape: Shape of tensor.
tensor_shape: Shape of tensor. The tensor is expected to be 3D and its order of dimension
is supposed to be ``(sequence, batch, hidden)``.
dtype: dtype used in p2p communication. If ``None`` (default value),
torch.float32 will be used even if ``autocast`` is enabled.
grad_scaler:
disable_autocast:
deallocate_pipeline_outputs: If :obj:`True`, free the data of the output tensor of
each pipeline stage. Experimental.
sequence_parallel_enabled: Set to :obj:`True` for this function to handle sequence length.
When :obj:`True`, the sequence length on each tensor model parallel rank is updated
to :math:`original\_sequence\_length / tensor\_model\_parallel\_world\_size`.
Returns:
a list of loss `torch.Tensor`s if the last stage, empty list otherwise.
......@@ -56,22 +77,43 @@ def _forward_backward_pipelining_with_interleaving(
if not isinstance(model, list):
raise RuntimeError("`model` must be a list of `nn.Module`'s'")
num_model_chunks = len(model)
input_tensors = [[] for _ in range(num_model_chunks)]
output_tensors = [[] for _ in range(num_model_chunks)]
curr_iters = [0 for _ in range(num_model_chunks)]
losses_reduced = []
if deallocate_pipeline_outputs:
warnings.warn(
"`deallocate_pipeline_outputs` is experimental and subject to change. "
"This option is not recommended."
)
# mypy will complain about the following if statement
if sequence_parallel_enabled:
seq_length, batch_size, hidden = tensor_shape
tensor_shape = (
seq_length // parallel_state.get_tensor_model_parallel_world_size(),
batch_size,
hidden,
)
num_model_chunks: int = len(model)
input_tensors: List[List[Union[None, torch.Tensor]]] = [
[] for _ in range(num_model_chunks)
]
output_tensors: List[List[Union[None, torch.Tensor]]] = [
[] for _ in range(num_model_chunks)
]
curr_iters: List[int] = [0 for _ in range(num_model_chunks)]
losses_reduced: List[Union[None, torch.Tensor]] = []
if not forward_only:
output_tensor_grads = [[] for _ in range(num_model_chunks)]
output_tensor_grads: List[List[Union[None, torch.Tensor]]] = [
[] for _ in range(num_model_chunks)
]
pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
pipeline_parallel_rank = parallel_state.get_pipeline_model_parallel_rank()
pipeline_parallel_size: int = parallel_state.get_pipeline_model_parallel_world_size()
pipeline_parallel_rank: int = parallel_state.get_pipeline_model_parallel_rank()
# Compute number of warmup and remaining microbatches.
num_microbatches = get_num_microbatches() * num_model_chunks
all_warmup_microbatches = False
num_microbatches: int = get_num_microbatches() * num_model_chunks
all_warmup_microbatches: bool = False
if forward_only:
num_warmup_microbatches = num_microbatches
num_warmup_microbatches: int = num_microbatches
else:
# Run all forward passes and then all backward passes if number of
# microbatches is just the number of pipeline stages.
......@@ -83,10 +125,12 @@ def _forward_backward_pipelining_with_interleaving(
num_warmup_microbatches = num_microbatches
all_warmup_microbatches = True
else:
num_warmup_microbatches = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
num_warmup_microbatches = (
pipeline_parallel_size - pipeline_parallel_rank - 1
) * 2
num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size
num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
num_microbatches_remaining = num_microbatches - num_warmup_microbatches
num_microbatches_remaining: int = num_microbatches - num_warmup_microbatches
_logger.info(
f"num_microbatches: {num_microbatches}, "
......@@ -100,24 +144,26 @@ def _forward_backward_pipelining_with_interleaving(
def get_model_chunk_id(microbatch_id: int, forward: bool) -> int:
"""Helper function to get the model chunk ID given the iteration number."""
pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
microbatch_id_in_group = microbatch_id % (
pipeline_parallel_size * num_model_chunks
)
model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
if not forward:
model_chunk_id = num_model_chunks - model_chunk_id - 1
return model_chunk_id
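# Worked example (illustrative): with pipeline_parallel_size = 2 and
# num_model_chunks = 2, microbatch ids map to model chunks as
#   forward:  0 -> 0, 1 -> 0, 2 -> 1, 3 -> 1, 4 -> 0, ...
#   backward: 0 -> 1, 1 -> 1, 2 -> 0, 3 -> 0, 4 -> 1, ...
# i.e. each run of `pipeline_parallel_size` consecutive microbatches belongs to
# one virtual chunk, and the backward pass walks the chunks in reverse order.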
def forward_step_helper(microbatch_id, curr_iters):
def forward_step_helper(microbatch_id: int, curr_iters: List[int]) -> torch.Tensor:
"""Helper method to run forward step with model split into chunks
(run set_virtual_pipeline_model_parallel_rank() before calling
forward_step())."""
(run set_virtual_pipeline_model_parallel_rank() before calling forward_step()).
"""
model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
# forward step
if (
parallel_state.is_pipeline_first_stage() and
len(input_tensors[model_chunk_id]) == len(output_tensors[model_chunk_id])
):
if parallel_state.is_pipeline_first_stage() and len(
input_tensors[model_chunk_id]
) == len(output_tensors[model_chunk_id]):
input_tensors[model_chunk_id].append(None)
input_tensor = input_tensors[model_chunk_id][-1]
output_tensor = forward_step(
......@@ -126,6 +172,8 @@ def _forward_backward_pipelining_with_interleaving(
model[model_chunk_id],
input_tensor,
losses_reduced,
dtype,
disable_autocast,
)
curr_iters[model_chunk_id] += 1
output_tensors[model_chunk_id].append(output_tensor)
......@@ -137,11 +185,13 @@ def _forward_backward_pipelining_with_interleaving(
return output_tensor
def backward_step_helper(microbatch_id):
def backward_step_helper(microbatch_id: int) -> torch.Tensor:
"""Helper method to run backward step with model split into chunks
(run set_virtual_pipeline_model_parallel_rank() before calling
backward_step())."""
(run set_virtual_pipeline_model_parallel_rank() before calling backward_step()).
"""
model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
model_type = get_model_type(model[model_chunk_id])
parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
if parallel_state.is_pipeline_last_stage():
......@@ -150,7 +200,14 @@ def _forward_backward_pipelining_with_interleaving(
input_tensor = input_tensors[model_chunk_id].pop(0)
output_tensor = output_tensors[model_chunk_id].pop(0)
output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
input_tensor_grad = backward_step(input_tensor, output_tensor, output_tensor_grad)
input_tensor_grad = backward_step(
input_tensor,
output_tensor,
output_tensor_grad,
model_type=model_type,
grad_scaler=grad_scaler,
deallocate_pipeline_outputs=deallocate_pipeline_outputs,
)
return input_tensor_grad
......@@ -158,7 +215,14 @@ def _forward_backward_pipelining_with_interleaving(
# Run warmup forward passes.
###################################################################################################################
parallel_state.set_virtual_pipeline_model_parallel_rank(0)
input_tensors[0].append(p2p_communication.recv_forward(tensor_shape=tensor_shape))
input_tensors[0].append(
p2p_communication.recv_forward(
tensor_shape=tensor_shape,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
)
_logger.info("Warmup phase")
for k in range(num_warmup_microbatches):
_logger.debug(f"warmup iter: {k} / {num_warmup_microbatches}")
......@@ -172,7 +236,9 @@ def _forward_backward_pipelining_with_interleaving(
recv_prev = False
if k == (num_microbatches - 1):
recv_prev = False
_logger.debug(f"next fwd model chunk ID: {next_forward_model_chunk_id}, recv_prev: {recv_prev}")
_logger.debug(
f"next fwd model chunk ID: {next_forward_model_chunk_id}, recv_prev: {recv_prev}"
)
# Don't send tensor downstream if on last stage.
if parallel_state.is_pipeline_last_stage():
......@@ -181,7 +247,11 @@ def _forward_backward_pipelining_with_interleaving(
# Send and receive tensors as appropriate (send tensors computed
# in this iteration; receive tensors for next iteration).
if k == (num_warmup_microbatches - 1) and not forward_only and not all_warmup_microbatches:
if (
k == (num_warmup_microbatches - 1)
and not forward_only
and not all_warmup_microbatches
):
input_tensor_grad = None
recv_next = True
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
......@@ -196,12 +266,23 @@ def _forward_backward_pipelining_with_interleaving(
recv_prev=recv_prev,
recv_next=recv_next,
tensor_shape=tensor_shape,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
else:
_logger.debug("send fwd and receive fwd")
input_tensor = p2p_communication.send_forward_recv_forward(output_tensor, recv_prev=recv_prev, tensor_shape=tensor_shape)
input_tensor = p2p_communication.send_forward_recv_forward(
output_tensor,
recv_prev=recv_prev,
tensor_shape=tensor_shape,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
input_tensors[next_forward_model_chunk_id].append(input_tensor)
free_output_tensor(output_tensor, deallocate_pipeline_outputs)
###################################################################################################################
# Run 1F1B in steady state.
......@@ -229,7 +310,9 @@ def _forward_backward_pipelining_with_interleaving(
backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
_logger.debug(f"fwd/bwd model chunk id: {forward_model_chunk_id}/{backward_model_chunk_id}")
_logger.debug(
f"fwd/bwd model chunk id: {forward_model_chunk_id}/{backward_model_chunk_id}"
)
if parallel_state.is_pipeline_first_stage():
input_tensor_grad = None
......@@ -245,7 +328,9 @@ def _forward_backward_pipelining_with_interleaving(
recv_prev = False
next_forward_model_chunk_id += 1
else:
next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True)
next_forward_model_chunk_id = get_model_chunk_id(
forward_k + 1, forward=True
)
recv_next = True
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
......@@ -257,7 +342,9 @@ def _forward_backward_pipelining_with_interleaving(
recv_next = False
next_backward_model_chunk_id -= 1
else:
next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False)
next_backward_model_chunk_id = get_model_chunk_id(
backward_k + 1, forward=False
)
# If last iteration, don't receive; we already received one extra
# before the start of the for loop.
......@@ -275,7 +362,11 @@ def _forward_backward_pipelining_with_interleaving(
recv_prev=recv_prev,
recv_next=recv_next,
tensor_shape=tensor_shape,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
free_output_tensor(output_tensor, deallocate_pipeline_outputs)
# Put input_tensor and output_tensor_grad in data structures in the
# right location.
......@@ -290,9 +381,18 @@ def _forward_backward_pipelining_with_interleaving(
_logger.info("Cooldown phase")
if not forward_only:
if all_warmup_microbatches:
output_tensor_grads[num_model_chunks - 1].append(p2p_communication.recv_backward(tensor_shape=tensor_shape))
output_tensor_grads[num_model_chunks - 1].append(
p2p_communication.recv_backward(
tensor_shape=tensor_shape,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
)
for k in range(num_microbatches_remaining, num_microbatches):
_logger.debug(f"cooldown iter {k} in range({num_microbatches_remaining}, {num_microbatches})")
_logger.debug(
f"cooldown iter {k} in range({num_microbatches_remaining}, {num_microbatches})"
)
input_tensor_grad = backward_step_helper(k)
next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
recv_next = True
......@@ -302,7 +402,14 @@ def _forward_backward_pipelining_with_interleaving(
if k == (num_microbatches - 1):
recv_next = False
output_tensor_grads[next_backward_model_chunk_id].append(
p2p_communication.send_backward_recv_backward(input_tensor_grad, recv_next=recv_next, tensor_shape=tensor_shape)
p2p_communication.send_backward_recv_backward(
input_tensor_grad,
recv_next=recv_next,
tensor_shape=tensor_shape,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
)
return losses_reduced
from typing import Union, List, Optional
from typing import Union, List, Optional, Sequence
import warnings
import torch
from apex.transformer import parallel_state
from apex.transformer.enums import ModelType
from apex.transformer.pipeline_parallel import p2p_communication
from apex.transformer.pipeline_parallel.p2p_communication import FutureTensor
from apex.transformer.pipeline_parallel.utils import get_kth_microbatch
from apex.transformer.pipeline_parallel.utils import listify_model
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.schedules.common import Batch, FwdStepFunc
from apex.transformer.pipeline_parallel.schedules.common import forward_step
from apex.transformer.pipeline_parallel.utils import get_model_type
from apex.transformer.pipeline_parallel.schedules.common import Batch
from apex.transformer.pipeline_parallel.schedules.common import FwdStepFunc
from apex.transformer.pipeline_parallel.schedules.common import backward_step
from apex.transformer.pipeline_parallel.schedules.common import forward_step
from apex.transformer.pipeline_parallel.schedules.common import free_output_tensor
from apex.transformer.log_util import get_transformer_logger
......@@ -19,14 +25,222 @@ __all__ = ["forward_backward_pipelining_without_interleaving"]
_logger = get_transformer_logger(__name__)
def get_tensor_shapes(
rank: int,
model_type: ModelType,
*,
tensor_shape: Union[List[int], torch.Size],
decoder_sequence_length: Optional[int] = None,
sequence_parallel_enabled: bool = False,
) -> Sequence[Sequence[int]]:
"""Get tensors shapes
Args:
rank: pipeline parallel rank
model_type:
Keyword Args:
tensor_shape:
decoder_sequence_length:
sequence_parallel_enabled:
"""
# Determine right tensor sizes (based on position of rank with respect to split
# rank) and model size.
# Send two tensors if model is T5 and rank is in decoder stage:
# first tensor is decoder (pre-transpose),
# second tensor is encoder (post-transpose).
# If model is T5 and rank is at the boundary:
# send one tensor (post-transpose from encoder).
# Otherwise, send one tensor (pre-transpose).
assert (
len(tensor_shape) == 3
), f"`tensor_shape` should be [sequence_length, micro_batch_size, hidden_size] but {tensor_shape}"
sequence_length, micro_batch_size, hidden_size = tensor_shape
tensor_shapes = []
if sequence_parallel_enabled:
seq_length = sequence_length // parallel_state.get_tensor_model_parallel_world_size()
else:
seq_length = sequence_length
if model_type == ModelType.encoder_and_decoder:
if sequence_parallel_enabled:
dec_seq_length = decoder_sequence_length // parallel_state.get_tensor_model_parallel_world_size()
else:
dec_seq_length = decoder_sequence_length
if parallel_state.is_pipeline_stage_before_split(rank):
tensor_shapes.append((seq_length, micro_batch_size, hidden_size))
else:
tensor_shapes.append((dec_seq_length, micro_batch_size, hidden_size))
tensor_shapes.append((seq_length, micro_batch_size, hidden_size))
else:
tensor_shapes.append((seq_length, micro_batch_size, hidden_size))
return tensor_shapes
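# Worked example (illustrative): for ModelType.encoder_and_decoder with
# tensor_shape = (512, 4, 1024), decoder_sequence_length = 128, and
# sequence_parallel_enabled = False:
#   * a rank before the encoder/decoder split gets [(512, 4, 1024)];
#   * a rank after the split gets [(128, 4, 1024), (512, 4, 1024)]
#     (decoder activations first, then the encoder hidden state).
# For ModelType.encoder_or_decoder every rank exchanges the single shape
# [(512, 4, 1024)].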
def recv_forward(
tensor_shapes: List[Union[None, List[int]]],
*,
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
) -> List[Union[None, torch.Tensor, FutureTensor]]:
input_tensors = []
for tensor_shape in tensor_shapes:
if tensor_shape is None:
input_tensors.append(None)
else:
input_tensors.append(
p2p_communication.recv_forward(
tensor_shape=tensor_shape,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
)
return input_tensors
def recv_backward(
tensor_shapes: List[Union[None, List[int]]],
*,
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
) -> List[Union[None, torch.Tensor, FutureTensor]]:
output_tensor_grads = []
for tensor_shape in tensor_shapes:
if tensor_shape is None:
output_tensor_grads.append(None)
else:
output_tensor_grads.append(
p2p_communication.recv_backward(
tensor_shape=tensor_shape,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
)
return output_tensor_grads
def send_forward(
output_tensors: Union[torch.Tensor, List[Union[None, torch.Tensor]]],
tensor_shapes: List[Union[None, List[int]]],
*,
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
) -> None:
if not isinstance(output_tensors, list):
output_tensors = [output_tensors]
for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes):
if tensor_shape is None:
continue
p2p_communication.send_forward(
output_tensor,
tensor_shape=tensor_shape,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
def send_backward(
input_tensor_grads: Union[torch.Tensor, List[Union[None, torch.Tensor]]],
tensor_shapes: List[Union[None, List[int]]],
*,
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
) -> None:
if not isinstance(input_tensor_grads, list):
input_tensor_grads = [input_tensor_grads]
for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes):
if tensor_shape is None:
continue
p2p_communication.send_backward(
input_tensor_grad,
tensor_shape=tensor_shape,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
def send_forward_recv_backward(
output_tensors: Union[torch.Tensor, List[Union[None, torch.Tensor]]],
tensor_shapes: List[Union[None, List[int]]],
*,
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
) -> List[Union[None, torch.Tensor, FutureTensor]]:
if not isinstance(output_tensors, list):
output_tensors = [output_tensors]
output_tensor_grads = []
for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes):
if tensor_shape is None:
output_tensor_grads.append(None)
continue
output_tensor_grad = p2p_communication.send_forward_recv_backward(
output_tensor,
tensor_shape=tensor_shape,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
output_tensor_grads.append(output_tensor_grad)
return output_tensor_grads
def send_backward_recv_forward(
input_tensor_grads: Union[torch.Tensor, List[Union[None, torch.Tensor]]],
tensor_shapes: List[Union[None, List[int]]],
*,
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
) -> List[Union[None, torch.Tensor, FutureTensor]]:
if not isinstance(input_tensor_grads, list):
input_tensor_grads = [input_tensor_grads]
input_tensors = []
for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes):
if tensor_shape is None:
input_tensors.append(None)
continue
input_tensor = p2p_communication.send_backward_recv_forward(
input_tensor_grad,
tensor_shape=tensor_shape,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
input_tensors.append(input_tensor)
return input_tensors
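# Note on the wrappers above (illustrative summary): each one takes a list of
# per-slot shapes so that encoder-and-decoder stages, which exchange two
# tensors per step, reuse the same code path as single-tensor stages. A `None`
# entry means nothing travels in that slot; the receive helpers keep `None`
# placeholders so list positions stay aligned with `tensor_shapes`.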
def forward_backward_pipelining_without_interleaving(
forward_step_func: FwdStepFunc,
batch: Batch,
model: Union[torch.nn.Module, List[torch.nn.Module]],
*,
forward_only: bool,
tensor_shape: Optional[Union[List[int], torch.Size]] = None,
):
forward_step_func: FwdStepFunc,
batch: Optional[Batch],
model: Union[torch.nn.Module, List[torch.nn.Module]],
*,
forward_only: bool,
tensor_shape: Optional[Union[List[int], torch.Size]] = None,
decoder_sequence_length: Optional[int] = None,
dtype: Optional[torch.dtype] = None,
grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
disable_autocast: bool = False,
deallocate_pipeline_outputs: bool = False,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
**kwargs,
) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
"""Run non-interleaved 1F1B schedule, with communication between pipeline stages.
This pipeline parallel scheduling consists of three steps:
......@@ -44,28 +258,59 @@ def forward_backward_pipelining_without_interleaving(
Keyword args:
forward_only:
tensor_shape: Shape of tensor. Required for P2P communication.
tensor_shape: Shape of tensor. The tensor is expected to be 3D and its order of dimension
is supposed to be ``(sequence, batch, hidden)``.
dtype: dtype used in p2p communication. If ``None`` (default value),
torch.float32 will be used even if ``autocast`` is enabled.
grad_scaler:
disable_autocast:
deallocate_pipeline_outputs: If :obj:`True`, free the data of the output tensor of
each pipeline stage. Experimental.
sequence_parallel_enabled: Set to :obj:`True` for this function to handle sequence length.
When :obj:`True`, the sequence length on each tensor model parallel rank is updated
to :math:`original\_sequence\_length / tensor\_model\_parallel\_world\_size`.
Returns:
a list of loss `torch.Tensor`s if the last stage, empty list otherwise.
"""
# timers = get_timers()
model = listify_model(model)
if deallocate_pipeline_outputs:
warnings.warn(
"`deallocate_pipeline_outputs` is experimental and subject to change. "
"This option is not recommended."
)
model: List[torch.nn.Module] = listify_model(model)
if len(model) != 1:
msg = f"`model` is expected be a `nn.Module`, but {type(model)}"
raise RuntimeError(msg)
model = model[0]
model: torch.nn.Module = model[0]
# Compute number of warmup microbatches.
num_microbatches = get_num_microbatches()
num_warmup_microbatches = (
parallel_state.get_pipeline_model_parallel_world_size()
- parallel_state.get_pipeline_model_parallel_rank()
- 1
num_microbatches: int = get_num_microbatches()
num_warmup_microbatches: int = (
parallel_state.get_pipeline_model_parallel_world_size() - parallel_state.get_pipeline_model_parallel_rank() - 1
)
num_warmup_microbatches: int = min(num_warmup_microbatches, num_microbatches)
num_microbatches_remaining: int = num_microbatches - num_warmup_microbatches
model_type = get_model_type(model)
rank: int = parallel_state.get_pipeline_model_parallel_rank()
recv_tensor_shapes: List[List[int]] = get_tensor_shapes(
rank - 1,
model_type,
tensor_shape=tensor_shape,
decoder_sequence_length=decoder_sequence_length,
sequence_parallel_enabled=sequence_parallel_enabled,
)
send_tensor_shapes: List[List[int]] = get_tensor_shapes(
rank,
model_type,
tensor_shape=tensor_shape,
decoder_sequence_length=decoder_sequence_length,
sequence_parallel_enabled=sequence_parallel_enabled,
)
num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
num_microbatches_remaining = num_microbatches - num_warmup_microbatches
_logger.info(
f"num_microbatches: {num_microbatches}, "
......@@ -74,13 +319,9 @@ def forward_backward_pipelining_without_interleaving(
)
# Input, output tensors only need to be saved when doing backward passes
input_tensors = None
output_tensors = None
if not forward_only:
input_tensors = []
output_tensors = []
losses_reduced = []
input_tensors: List[Union[None, torch.Tensor]] = []
output_tensors: List[Union[None, torch.Tensor]] = []
losses_reduced: List[Union[None, torch.Tensor]] = []
###################################################################################################################
# Run warmup forward passes.
###################################################################################################################
......@@ -88,22 +329,42 @@ def forward_backward_pipelining_without_interleaving(
for i in range(num_warmup_microbatches):
_logger.debug(f"warmup iter: {i} / {num_warmup_microbatches}")
_logger.debug("receive fwd")
input_tensor = p2p_communication.recv_forward(tensor_shape=tensor_shape)
cur_microbatch = get_kth_microbatch(batch, i)
output_tensor = forward_step(forward_step_func, cur_microbatch, model, input_tensor, losses_reduced)
input_tensor = recv_forward(
tensor_shapes=recv_tensor_shapes,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
cur_microbatch: Optional[torch.Tensor] = get_kth_microbatch(batch, i)
output_tensor = forward_step(
forward_step_func,
cur_microbatch,
model,
input_tensor,
losses_reduced,
dtype,
disable_autocast,
)
_logger.debug("send fwd")
p2p_communication.send_forward(output_tensor, tensor_shape=tensor_shape)
send_forward(
output_tensor,
tensor_shapes=send_tensor_shapes,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
if not forward_only:
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
free_output_tensor(output_tensor, deallocate_pipeline_outputs)
# Before running 1F1B, need to receive first forward tensor.
# If all microbatches are run in warmup / cooldown phase, then no need to
# receive this tensor here.
if num_microbatches_remaining > 0:
_logger.debug("recv_forward before steady state start")
input_tensor = p2p_communication.recv_forward(tensor_shape=tensor_shape)
input_tensor: List[Union[None, torch.Tensor, FutureTensor]] = recv_forward(tensor_shapes=recv_tensor_shapes, dtype=dtype, async_comm=async_comm)
###################################################################################################################
# Run 1F1B in steady state.
......@@ -111,42 +372,84 @@ def forward_backward_pipelining_without_interleaving(
_logger.info("Steady phase")
for i in range(num_microbatches_remaining):
_logger.debug(f"steady iter: {i} / {num_microbatches_remaining}")
last_iteration = i == (num_microbatches_remaining - 1)
last_iteration: bool = i == (num_microbatches_remaining - 1)
cur_microbatch = get_kth_microbatch(batch, i + num_warmup_microbatches)
output_tensor = forward_step(forward_step_func, cur_microbatch, model, input_tensor, losses_reduced)
cur_microbatch: Optional[torch.Tensor] = get_kth_microbatch(batch, i + num_warmup_microbatches)
output_tensor: Union[torch.Tensor, Sequence[torch.Tensor]] = forward_step(
forward_step_func,
cur_microbatch,
model,
input_tensor,
losses_reduced,
dtype,
disable_autocast,
)
if forward_only:
_logger.debug("send fwd")
p2p_communication.send_forward(output_tensor, tensor_shape=tensor_shape)
send_forward(
output_tensor,
tensor_shapes=send_tensor_shapes,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
if not last_iteration:
_logger.debug("receive fwd (last iteration)")
input_tensor = p2p_communication.recv_forward(tensor_shape=tensor_shape)
input_tensor = recv_forward(
tensor_shapes=recv_tensor_shapes,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
else:
_logger.debug("send fwd & receive bwd")
output_tensor_grad = p2p_communication.send_forward_recv_backward(output_tensor, tensor_shape=tensor_shape)
output_tensor_grad = send_forward_recv_backward(
output_tensor,
tensor_shapes=send_tensor_shapes,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
# Add input_tensor and output_tensor to end of list.
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
free_output_tensor(output_tensor, deallocate_pipeline_outputs)
# Pop input_tensor and output_tensor from the start of the list for the backward pass.
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
input_tensor_grad = backward_step(
input_tensor, output_tensor, output_tensor_grad
input_tensor,
output_tensor,
output_tensor_grad,
model_type=model_type,
grad_scaler=grad_scaler,
deallocate_pipeline_outputs=deallocate_pipeline_outputs,
)
if last_iteration:
input_tensor = None
_logger.debug("send bwd")
p2p_communication.send_backward(input_tensor_grad, tensor_shape=tensor_shape)
send_backward(
input_tensor_grad,
tensor_shapes=recv_tensor_shapes,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
else:
_logger.debug("send bwd and receive fwd")
input_tensor = p2p_communication.send_backward_recv_forward(
input_tensor_grad, tensor_shape=tensor_shape)
input_tensor = send_backward_recv_forward(
input_tensor_grad,
tensor_shapes=recv_tensor_shapes,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
###################################################################################################################
# Run cooldown backward passes.
###################################################################################################################
......@@ -158,13 +461,29 @@ def forward_backward_pipelining_without_interleaving(
output_tensor = output_tensors.pop(0)
_logger.debug("receive bwd")
output_tensor_grad = p2p_communication.recv_backward(tensor_shape=tensor_shape)
output_tensor_grad = recv_backward(
tensor_shapes=send_tensor_shapes,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
input_tensor_grad = backward_step(
input_tensor, output_tensor, output_tensor_grad
input_tensor,
output_tensor,
output_tensor_grad,
model_type=model_type,
grad_scaler=grad_scaler,
deallocate_pipeline_outputs=deallocate_pipeline_outputs,
)
_logger.debug("send bwd")
p2p_communication.send_backward(input_tensor_grad, tensor_shape=tensor_shape)
send_backward(
input_tensor_grad,
tensor_shapes=recv_tensor_shapes,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
return losses_reduced
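# Illustrative usage sketch (hypothetical, not part of apex): running one
# training step through this schedule. In practice the function is obtained via
# `get_forward_backward_func`; the shapes, dtype, and placeholder arguments
# below are assumptions, and parallel state plus the microbatch calculator must
# already be initialized.
def _example_pipeline_step(
    forward_step_fn: FwdStepFunc, batch: Batch, model: torch.nn.Module
) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
    return forward_backward_pipelining_without_interleaving(
        forward_step_fn,
        batch,
        model,
        forward_only=False,
        tensor_shape=(2048, 1, 4096),  # (sequence, micro batch, hidden)
        dtype=torch.half,
        grad_scaler=None,
        sequence_parallel_enabled=False,
    )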
......@@ -21,6 +21,7 @@ from torch.nn.parallel import DistributedDataParallel
from apex.multi_tensor_apply import multi_tensor_applier
from apex.transformer import parallel_state
from apex.transformer.enums import ModelType
from apex.transformer.microbatches import build_num_microbatches_calculator
from apex.transformer.pipeline_parallel._timers import _Timers
if multi_tensor_applier.available:
......@@ -118,14 +119,24 @@ def _split_batch_into_microbatch(
# TODO(mkozuki): Support non-tensor local minibatches?
def get_kth_microbatch(batch: List[torch.Tensor], k: int) -> List[torch.Tensor]:
def get_kth_microbatch(batch: Optional[List[torch.Tensor]], k: int) -> List[torch.Tensor]:
"""Create a list of microbatches from a list of local minibatches.
This function creates a list of `k`th microbatches from a list of local minibatches.
`a local minibatch` consists of `global_batch_size / data_parallel_size` samples.
"""
if batch is None:
return batch
micro_batch_size = get_micro_batch_size()
return [x[k * micro_batch_size:(k + 1) * micro_batch_size] for x in batch]
start = k * micro_batch_size
end = start + micro_batch_size
microbatch = list()
for x in batch:
size = x.size(0)
assert size > start and size >= end
microbatch.append(x[start:end])
assert len(microbatch) > 0
return microbatch
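# Worked example (illustrative): with get_micro_batch_size() == 2 and a local
# minibatch `batch = [tokens, labels]` whose tensors each hold 8 samples along
# dim 0, `get_kth_microbatch(batch, k=1)` returns `[tokens[2:4], labels[2:4]]`.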
def get_autoresume():
......@@ -186,6 +197,19 @@ def unwrap_model(model, module_instances=(DistributedDataParallel,)):
return unwrapped_model
def get_model_type(
model: torch.nn.Module,
) -> ModelType:
"""Get `model_type` of `model`.
If ``model`` doesn't have ``model_type`` attribute, return ``ModelType.encoder_or_decoder``.
Args:
model
"""
return getattr(unwrap_model(model), "model_type", ModelType.encoder_or_decoder)
def calc_params_l2_norm(model: torch.nn.Module, bf16: bool):
"""Calculate l2 norm of parameters """
# args = get_args()
......
# coding=utf-8
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -33,6 +32,7 @@ from apex.transformer.tensor_parallel.mappings import (
gather_from_tensor_model_parallel_region,
reduce_from_tensor_model_parallel_region,
scatter_to_tensor_model_parallel_region,
scatter_to_sequence_parallel_region,
)
from .random import (
......@@ -63,6 +63,7 @@ __all__ = [
"gather_from_tensor_model_parallel_region",
"reduce_from_tensor_model_parallel_region",
"scatter_to_tensor_model_parallel_region",
"scatter_to_sequence_parallel_region",
# random.py
"checkpoint",
"get_cuda_rng_tracker",
......
......@@ -25,8 +25,9 @@ _MAX_DATA_DIM = 5
def _check_data_types(keys, data, target_dtype):
"""Check that all the keys have the same target data type."""
for key in keys:
assert data[key].dtype == target_dtype, "{} has data type {} which " "is different than {}".format(
key, data[key].dtype, target_dtype
assert data[key].dtype == target_dtype, (
"{} has data type {} which "
"is different than {}".format(key, data[key].dtype, target_dtype)
)
......@@ -48,7 +49,9 @@ def _build_key_size_numel_dictionaries(keys, data):
# Move to GPU and broadcast.
sizes_cuda = torch.cuda.LongTensor(sizes)
torch.distributed.broadcast(
sizes_cuda, get_tensor_model_parallel_src_rank(), group=get_tensor_model_parallel_group(),
sizes_cuda,
get_tensor_model_parallel_src_rank(),
group=get_tensor_model_parallel_group(),
)
# Move back to cpu and unpack.
......@@ -92,13 +95,19 @@ def broadcast_data(keys, data, datatype):
# Check that all keys have the same data type.
_check_data_types(keys, data, datatype)
# Flatten the data associated with the keys
flatten_data = torch.cat([data[key].contiguous().view(-1) for key in keys], dim=0).cuda()
flatten_data = torch.cat(
[data[key].contiguous().view(-1) for key in keys], dim=0
).cuda()
else:
flatten_data = torch.empty(total_numel, device=torch.cuda.current_device(), dtype=datatype)
flatten_data = torch.empty(
total_numel, device=torch.cuda.current_device(), dtype=datatype
)
# Broadcast
torch.distributed.broadcast(
flatten_data, get_tensor_model_parallel_src_rank(), group=get_tensor_model_parallel_group(),
flatten_data,
get_tensor_model_parallel_src_rank(),
group=get_tensor_model_parallel_group(),
)
# Unpack
......
# coding=utf-8
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021-22, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -16,6 +16,9 @@
# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch
from typing import Optional, Dict, Tuple, List
import warnings
import torch
import torch.nn.functional as F
import torch.nn.init as init
......@@ -26,12 +29,34 @@ from apex.transformer.parallel_state import get_tensor_model_parallel_group
from apex.transformer.parallel_state import get_tensor_model_parallel_rank
from apex.transformer.parallel_state import get_tensor_model_parallel_world_size
from apex.transformer.utils import divide
from apex.transformer.tensor_parallel.mappings import copy_to_tensor_model_parallel_region
from apex.transformer.tensor_parallel.mappings import gather_from_tensor_model_parallel_region
from apex.transformer.tensor_parallel.mappings import reduce_from_tensor_model_parallel_region
from apex.transformer.tensor_parallel.mappings import scatter_to_tensor_model_parallel_region
from apex.transformer.tensor_parallel.mappings import (
copy_to_tensor_model_parallel_region,
)
from apex.transformer.tensor_parallel.mappings import (
gather_from_tensor_model_parallel_region,
)
from apex.transformer.tensor_parallel.mappings import (
reduce_from_tensor_model_parallel_region,
)
from apex.transformer.tensor_parallel.mappings import (
scatter_to_tensor_model_parallel_region,
)
from apex.transformer.tensor_parallel.mappings import (
reduce_scatter_to_sequence_parallel_region,
)
from apex.transformer.tensor_parallel.random import get_cuda_rng_tracker
from apex.transformer.tensor_parallel.utils import VocabUtility
from apex.transformer.log_util import get_transformer_logger
_logger = get_transformer_logger(__name__)
_grad_accum_fusion_available = True
try:
import fused_weight_gradient_mlp_cuda
except ImportError:
_grad_accum_fusion_available = False
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
......@@ -41,13 +66,13 @@ _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
}
def param_is_not_tensor_parallel_duplicate(param):
return (hasattr(param, "tensor_model_parallel") and param.tensor_model_parallel) or (
get_tensor_model_parallel_rank() == 0
)
def param_is_not_tensor_parallel_duplicate(param: torch.Tensor) -> bool:
return (
hasattr(param, "tensor_model_parallel") and param.tensor_model_parallel
) or (get_tensor_model_parallel_rank() == 0)
def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
def set_tensor_model_parallel_attributes(tensor: torch.Tensor, is_parallel: bool, dim: int, stride: int) -> None:
# Make sure the attributes are not set.
for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
assert not hasattr(tensor, attribute)
......@@ -57,7 +82,7 @@ def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
setattr(tensor, "partition_stride", stride)
def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor):
def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor: torch.Tensor) -> None:
def maybe_set(attribute, value):
if not hasattr(tensor, attribute):
setattr(tensor, attribute, value)
......@@ -66,7 +91,7 @@ def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor):
maybe_set(attribute, _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS[attribute])
def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
def copy_tensor_model_parallel_attributes(destination_tensor: torch.Tensor, source_tensor: torch.Tensor) -> None:
def maybe_copy(attribute):
if hasattr(source_tensor, attribute):
setattr(destination_tensor, attribute, getattr(source_tensor, attribute))
......@@ -76,9 +101,18 @@ def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
def _initialize_affine_weight_gpu(weight, init_method, partition_dim, stride=1):
"""Initialize affine weight for model parallel on GPU."""
"""Initialize affine weight for model parallel on GPU.
set_tensor_model_parallel_attributes(tensor=weight, is_parallel=True, dim=partition_dim, stride=stride)
Args:
weight (Parameter):
init_method (Callable[[Tensor], None]): Takes a Tensor and initializes its elements in place.
partition_dim (int): Dimension to apply partition.
stride (int):
"""
set_tensor_model_parallel_attributes(
tensor=weight, is_parallel=True, dim=partition_dim, stride=stride
)
with get_cuda_rng_tracker().fork():
init_method(weight)
......@@ -103,16 +137,22 @@ def _initialize_affine_weight_cpu(
Build the master weight on all processes and scatter
the relevant chunk."""
set_tensor_model_parallel_attributes(tensor=weight, is_parallel=True, dim=partition_dim, stride=stride)
set_tensor_model_parallel_attributes(
tensor=weight, is_parallel=True, dim=partition_dim, stride=stride
)
# Initialize master weight
master_weight = torch.empty(output_size, input_size, dtype=torch.float, requires_grad=False)
master_weight = torch.empty(
output_size, input_size, dtype=torch.float, requires_grad=False
)
init_method(master_weight)
master_weight = master_weight.to(dtype=params_dtype)
# Split and copy
per_partition_per_stride_size = divide(per_partition_size, stride)
weight_list = torch.split(master_weight, per_partition_per_stride_size, dim=partition_dim)
weight_list = torch.split(
master_weight, per_partition_per_stride_size, dim=partition_dim
)
rank = get_tensor_model_parallel_rank()
world_size = get_tensor_model_parallel_world_size()
my_weight_list = weight_list[rank::world_size]
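# Worked example of the split above (hypothetical sizes): with output_size=8, world_size=2,
# stride=2 and partition_dim=0, per_partition_size is 4 and per_partition_per_stride_size is 2,
# so
#
#     chunks = torch.split(master_weight, 2, dim=0)   # four chunks of two rows each
#     chunks[0::2]   # rank 0 keeps rows 0-1 and 4-5
#     chunks[1::2]   # rank 1 keeps rows 2-3 and 6-7
#
# and each rank copies its interleaved chunks into its local `weight` partition.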
......@@ -136,9 +176,15 @@ class VocabParallelEmbedding(torch.nn.Module):
"""
def __init__(
self, num_embeddings, embedding_dim, init_method=init.xavier_normal_, *, params_dtype=torch.float32, use_cpu_initialization=False,
self,
num_embeddings: int,
embedding_dim: int,
init_method=init.xavier_normal_,
*,
params_dtype: torch.dtype=torch.float32,
use_cpu_initialization: bool = False,
):
super(VocabParallelEmbedding, self).__init__()
super().__init__()
# Keep the input dimensions.
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
......@@ -150,19 +196,35 @@ class VocabParallelEmbedding(torch.nn.Module):
self.sparse = False
self._weight = None
self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
# Divide the weight matrix along the vocaburaly dimension.
self.vocab_start_index, self.vocab_end_index = VocabUtility.vocab_range_from_global_vocab_size(
self.num_embeddings, get_tensor_model_parallel_rank(), self.tensor_model_parallel_size
# Divide the weight matrix along the vocabulary dimension.
(
self.vocab_start_index,
self.vocab_end_index,
) = VocabUtility.vocab_range_from_global_vocab_size(
self.num_embeddings,
get_tensor_model_parallel_rank(),
self.tensor_model_parallel_size,
)
self.num_embeddings_per_partition = (
self.vocab_end_index - self.vocab_start_index
)
self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index
# Allocate weights and initialize.
if use_cpu_initialization:
self.weight = Parameter(
torch.empty(self.num_embeddings_per_partition, self.embedding_dim, dtype=params_dtype)
torch.empty(
self.num_embeddings_per_partition,
self.embedding_dim,
dtype=params_dtype,
)
)
_initialize_affine_weight_cpu(
self.weight, self.num_embeddings, self.embedding_dim, self.num_embeddings_per_partition, 0, init_method,
self.weight,
self.num_embeddings,
self.embedding_dim,
self.num_embeddings_per_partition,
0,
init_method,
params_dtype=params_dtype,
)
else:
......@@ -174,12 +236,16 @@ class VocabParallelEmbedding(torch.nn.Module):
dtype=params_dtype,
)
)
_initialize_affine_weight_gpu(self.weight, init_method, partition_dim=0, stride=1)
_initialize_affine_weight_gpu(
self.weight, init_method, partition_dim=0, stride=1
)
def forward(self, input_):
if self.tensor_model_parallel_size > 1:
# Build the mask.
input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
input_mask = (input_ < self.vocab_start_index) | (
input_ >= self.vocab_end_index
)
# Mask the input.
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
......@@ -203,16 +269,44 @@ class VocabParallelEmbedding(torch.nn.Module):
return output
class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function):
"""
Column-parallel linear layer execution with asynchronous all-reduce
execution in backprop.
"""
class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
"""Linear layer execution with asynchronous communication and gradient accumulation fusion in backprop."""
@staticmethod
def forward(ctx, input, weight, bias):
def forward(
ctx,
input: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor],
gradient_accumulation_fusion: bool,
async_grad_allreduce: bool,
sequence_parallel_enabled: bool,
use_16bit_in_wgrad_accum_fusion: bool = False,
):
ctx.save_for_backward(input, weight)
ctx.use_bias = bias is not None
output = torch.matmul(input, weight.t())
ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
ctx.async_grad_allreduce = async_grad_allreduce
ctx.sequence_parallel_enabled = sequence_parallel_enabled
ctx.use_16bit_in_wgrad_accum_fusion = use_16bit_in_wgrad_accum_fusion
if ctx.sequence_parallel_enabled:
world_size = get_tensor_model_parallel_world_size()
# `input` is expected to be 3D, with dimensions ordered as [sequence, batch, hidden]
shape = list(input.shape)
shape[0] *= world_size
all_gather_buffer = torch.empty(
shape,
dtype=input.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
torch.distributed._all_gather_base(all_gather_buffer, input, group=get_tensor_model_parallel_group())
total_input = all_gather_buffer
else:
total_input = input
output = torch.matmul(total_input, weight.t())
if bias is not None:
output = output + bias
return output
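# Shape bookkeeping for the sequence-parallel path above (illustrative, with a hypothetical
# tensor model parallel world size p): each rank holds `input` of shape [s/p, b, h], the
# all-gather along dim 0 rebuilds `total_input` of shape [s, b, h], and the matmul with
# `weight.t()` (shape [h, out/p]) produces `output` of shape [s, b, out/p] on every rank.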
......@@ -221,23 +315,115 @@ class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function):
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
use_bias = ctx.use_bias
if ctx.sequence_parallel_enabled:
world_size = get_tensor_model_parallel_world_size()
shape = list(input.shape)
shape[0] *= world_size
all_gather_buffer = torch.empty(
shape,
dtype=input.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
handle = torch.distributed._all_gather_base(
all_gather_buffer,
input,
group=get_tensor_model_parallel_group(),
async_op=True,
)
total_input = all_gather_buffer
else:
total_input = input
grad_input = grad_output.matmul(weight)
# Asynchronous all-reduce
handle = torch.distributed.all_reduce(
grad_input, group=get_tensor_model_parallel_group(), async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# all-reduce scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
grad_weight = grad_output.t().matmul(input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
handle.wait()
return grad_input, grad_weight, grad_bias
if ctx.sequence_parallel_enabled:
handle.wait()
# Convert the tensor shapes to 2D for execution compatibility
grad_output = grad_output.view(
grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]
)
total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2])
if ctx.async_grad_allreduce:
# Asynchronous all-reduce
handle = torch.distributed.all_reduce(
grad_input, group=get_tensor_model_parallel_group(), async_op=True
)
if ctx.sequence_parallel_enabled:
assert not ctx.async_grad_allreduce
sub_grad_input = torch.empty(input.shape, dtype=input.dtype, device=torch.cuda.current_device(), requires_grad=False)
handle = torch.distributed._reduce_scatter_base(
sub_grad_input,
grad_input,
group=get_tensor_model_parallel_group(),
async_op=True
)
if ctx.gradient_accumulation_fusion:
if not ctx.use_16bit_in_wgrad_accum_fusion:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(
total_input, grad_output, weight.main_grad
)
else:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(
total_input, grad_output, weight.main_grad
)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(total_input)
def column_parallel_linear(input, weight, bias):
args = _cast_if_autocast_enabled(input, weight, bias)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.sequence_parallel_enabled:
handle.wait()
return sub_grad_input, grad_weight, grad_bias, None, None, None, None
if ctx.async_grad_allreduce:
handle.wait()
return grad_input, grad_weight, grad_bias, None, None, None, None
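# Summary of the communication paths in the backward above (descriptive only):
# - sequence_parallel_enabled: the forward all-gather is replayed asynchronously (overlapping
#   with the grad_input GEMM) to rebuild `total_input`, and grad_input is reduce-scattered back
#   to its [s/p, b, h] shard while the weight gradient (or the fused wgrad kernel) is computed.
# - async_grad_allreduce: grad_input is all-reduced across the tensor model parallel group
#   asynchronously and the handle is awaited only after the weight gradient is computed.
# - neither flag: plain local GEMMs with no collectives, as in a single-rank run.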
def linear_with_grad_accumulation_and_async_allreduce(
input: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor],
gradient_accumulation_fusion: bool,
async_grad_allreduce: bool,
sequence_parallel_enabled: bool,
) -> torch.Tensor:
args = _cast_if_autocast_enabled(
input,
weight,
bias,
gradient_accumulation_fusion,
async_grad_allreduce,
sequence_parallel_enabled,
False, # use_16bit_in_wgrad_accum_fusion
)
with torch.cuda.amp.autocast(enabled=False):
return ColumnParallelLinearWithAsyncAllreduce.apply(*args)
return LinearWithGradAccumulationAndAsyncCommunication.apply(*args)
def linear_with_grad_accumulation_and_async_allreduce_in16bit(
input: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor],
gradient_accumulation_fusion: bool,
async_grad_allreduce: bool,
sequence_parallel_enabled: bool,
) -> torch.Tensor:
args = _cast_if_autocast_enabled(
input,
weight,
bias,
gradient_accumulation_fusion,
async_grad_allreduce,
sequence_parallel_enabled,
True, # use_16bit_in_wgrad_accum_fusion
)
with torch.cuda.amp.autocast(enabled=False):
return LinearWithGradAccumulationAndAsyncCommunication.apply(*args)
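# Standalone illustrative sketch (hypothetical shapes, single process): with every flag set to
# False the wrapper above issues no collectives, so the column-split equivalence that
# ColumnParallelLinear relies on can be checked with plain tensors. Nothing in this module
# calls this helper; it only documents the intended math.
def _sketch_column_split_equivalence() -> None:
    s, b, h, out = 4, 2, 8, 6
    x = torch.randn(s, b, h)
    w = torch.randn(out, h)  # full (unsplit) weight, laid out as [out, in] like F.linear
    full = linear_with_grad_accumulation_and_async_allreduce(
        input=x,
        weight=w,
        bias=None,
        gradient_accumulation_fusion=False,
        async_grad_allreduce=False,
        sequence_parallel_enabled=False,
    )
    # Pretend two ranks each own half of the output columns of A (i.e. half of the rows of `w`).
    w0, w1 = torch.chunk(w, 2, dim=0)
    partial = torch.cat([F.linear(x, w0), F.linear(x, w1)], dim=-1)
    assert torch.allclose(full, partial)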
class ColumnParallelLinear(torch.nn.Module):
......@@ -246,6 +432,10 @@ class ColumnParallelLinear(torch.nn.Module):
The linear layer is defined as Y = XA + b. A is parallelized along
its second dimension as A = [A_1, ..., A_p].
.. note::
Input is expected to be three-dimensional, with dimensions ordered as
[sequence, batch, hidden].
Arguments:
input_size: first dimension of matrix A.
output_size: second dimension of matrix A.
......@@ -262,6 +452,14 @@ class ColumnParallelLinear(torch.nn.Module):
skip_bias_add: This was added to enable performance optimizations where bias
can be fused with other elementwise operations. We skip
adding bias but instead return it.
Keyword Arguments:
no_async_tensor_model_parallel_allreduce:
params_dtype:
use_cpu_initialization:
gradient_accumulation_fusion:
accumulation_in_fp16:
sequence_parallel_enabled:
"""
def __init__(
......@@ -278,8 +476,11 @@ class ColumnParallelLinear(torch.nn.Module):
no_async_tensor_model_parallel_allreduce=False,
params_dtype=torch.float32,
use_cpu_initialization=False,
gradient_accumulation_fusion=False,
accumulation_in_fp16: bool = False,
sequence_parallel_enabled: bool = False,
):
super(ColumnParallelLinear, self).__init__()
super().__init__()
# Keep input parameters
self.input_size = input_size
......@@ -295,7 +496,9 @@ class ColumnParallelLinear(torch.nn.Module):
# we allocate the transpose.
# Initialize weight.
if use_cpu_initialization:
self.weight = Parameter(torch.empty(self.output_size_per_partition, self.input_size, dtype=params_dtype))
self.weight = Parameter(
torch.empty(self.output_size_per_partition, self.input_size, dtype=params_dtype)
)
self.master_weight = _initialize_affine_weight_cpu(
self.weight,
self.output_size,
......@@ -323,7 +526,11 @@ class ColumnParallelLinear(torch.nn.Module):
self.bias = Parameter(torch.empty(self.output_size_per_partition, dtype=params_dtype))
else:
self.bias = Parameter(
torch.empty(self.output_size_per_partition, device=torch.cuda.current_device(), dtype=params_dtype)
torch.empty(
self.output_size_per_partition,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
)
set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
# Always initialize bias to zero.
......@@ -333,28 +540,69 @@ class ColumnParallelLinear(torch.nn.Module):
self.register_parameter("bias", None)
self.async_tensor_model_parallel_allreduce = (
not no_async_tensor_model_parallel_allreduce and
world_size > 1)
not no_async_tensor_model_parallel_allreduce and world_size > 1
)
if sequence_parallel_enabled:
if world_size <= 1:
warnings.warn(
f"`sequence_parallel_enabled` is set to `True`, but got world_size of {world_size}"
)
# sequence_parallel_enabled = False
self.sequence_parallel_enabled = sequence_parallel_enabled
if gradient_accumulation_fusion:
if not _grad_accum_fusion_available:
# Basically, apex.transformer module users are expected to install APEX's
# `--cpp_ext` and `--cuda_ext`. The example installation command is as follows:
# `pip install --global-option="--cpp_ext" --global-option="--cuda_ext" .`
# at the root of APEX repository.
warnings.warn(
"`gradient_accumulation_fusion` is set to `True` but "
"the custom CUDA extension of `fused_weight_gradient_mlp_cuda` module not "
"found. Thus `gradient_accumulation_fusion` set to `False`. "
"Note that the extension requires CUDA>=11."
)
gradient_accumulation_fusion = False
self.gradient_accumulation_fusion = gradient_accumulation_fusion
def forward(self, input_):
if self.async_tensor_model_parallel_allreduce and self.sequence_parallel_enabled:
raise RuntimeError("`async_tensor_model_parallel_allreduce` and `sequence_parallel_enabled` cannot be enabled at the same time.")
self._forward_impl = (
linear_with_grad_accumulation_and_async_allreduce_in16bit
if accumulation_in_fp16
else linear_with_grad_accumulation_and_async_allreduce
)
def forward(self, input_: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Forward of ColumnParallelLinear
Args:
input_: 3D tensor whose dimensions are ordered as [sequence, batch, hidden]
Returns:
- output
- bias
"""
bias = self.bias if not self.skip_bias_add else None
if self.async_tensor_model_parallel_allreduce:
input_shape = input_.shape
input_ = input_.view(input_shape[0] * input_shape[1],input_shape[2])
# Matrix multiply with asynchronous all-reduce execution
output_parallel = column_parallel_linear(input_, self.weight, bias)
output_parallel = output_parallel.view(
input_shape[0], input_shape[1], output_parallel.shape[1])
if self.async_tensor_model_parallel_allreduce or self.sequence_parallel_enabled:
input_parallel = input_
else:
# Set up backprop all-reduce.
input_parallel = copy_to_tensor_model_parallel_region(input_)
# Matrix multiply.
output_parallel = F.linear(input_parallel, self.weight, bias)
# Matrix multiply.
output_parallel = self._forward_impl(
input=input_parallel,
weight=self.weight,
bias=bias,
gradient_accumulation_fusion=self.gradient_accumulation_fusion,
async_grad_allreduce=self.async_tensor_model_parallel_allreduce,
sequence_parallel_enabled=self.sequence_parallel_enabled,
)
if self.gather_output:
# All-gather across the partitions.
assert not self.sequence_parallel_enabled
output = gather_from_tensor_model_parallel_region(output_parallel)
else:
output = output_parallel
......@@ -374,6 +622,11 @@ class RowParallelLinear(torch.nn.Module):
| . |
| A_p |
- -
.. note::
Input is expected to be three-dimensional, with dimensions ordered as
[sequence, batch, hidden].
Arguments:
input_size: first dimension of matrix A.
output_size: second dimension of matrix A.
......@@ -390,6 +643,12 @@ class RowParallelLinear(torch.nn.Module):
skip_bias_add: This was added to enable performance optimization where bias
can be fused with other elementwise operations. We skip
adding bias but instead return it.
Keyword Arguments:
params_dtype:
use_cpu_initialization:
gradient_accumulation_fusion:
accumulation_in_fp16:
sequence_parallel_enabled:
"""
def __init__(
......@@ -405,8 +664,11 @@ class RowParallelLinear(torch.nn.Module):
*,
params_dtype=torch.float32,
use_cpu_initialization=False,
gradient_accumulation_fusion=False,
accumulation_in_fp16: bool = False,
sequence_parallel_enabled: bool = False,
):
super(RowParallelLinear, self).__init__()
super().__init__()
# Keep input parameters
self.input_size = input_size
......@@ -416,6 +678,10 @@ class RowParallelLinear(torch.nn.Module):
world_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, world_size)
self.skip_bias_add = skip_bias_add
self.gradient_accumulation_fusion = gradient_accumulation_fusion
self.sequence_parallel_enabled = sequence_parallel_enabled
if self.sequence_parallel_enabled and not self.input_is_parallel:
raise RuntimeError("To enable `sequence_parallel_enabled`, `input_is_parallel` must be `True`")
# as an argument to this function?
# Parameters.
......@@ -423,7 +689,11 @@ class RowParallelLinear(torch.nn.Module):
# we allocate the transpose.
# Initialize weight.
if use_cpu_initialization:
self.weight = Parameter(torch.empty(self.output_size, self.input_size_per_partition, dtype=params_dtype))
self.weight = Parameter(
torch.empty(
self.output_size, self.input_size_per_partition, dtype=params_dtype
)
)
self.master_weight = _initialize_affine_weight_cpu(
self.weight,
self.output_size,
......@@ -444,30 +714,63 @@ class RowParallelLinear(torch.nn.Module):
dtype=params_dtype,
)
)
_initialize_affine_weight_gpu(self.weight, init_method, partition_dim=1, stride=stride)
_initialize_affine_weight_gpu(
self.weight, init_method, partition_dim=1, stride=stride
)
if bias:
if use_cpu_initialization:
self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
else:
self.bias = Parameter(
torch.empty(self.output_size, device=torch.cuda.current_device(), dtype=params_dtype)
torch.empty(
self.output_size,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
)
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
setattr(self.bias, "sequence_parallel_enabled", sequence_parallel_enabled)
else:
self.register_parameter("bias", None)
def forward(self, input_):
self._forward_impl = (
linear_with_grad_accumulation_and_async_allreduce_in16bit
if accumulation_in_fp16
else linear_with_grad_accumulation_and_async_allreduce
)
def forward(self, input_: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Forward of RowParallelLinear
Args:
input_: 3D tensor whose dimensions are ordered as [sequence, batch, hidden]
Returns:
- output
- bias
"""
# Set up backprop all-reduce.
if self.input_is_parallel:
input_parallel = input_
else:
assert not self.sequence_parallel_enabled
input_parallel = scatter_to_tensor_model_parallel_region(input_)
# Matrix multiply.
output_parallel = F.linear(input_parallel, self.weight)
output_parallel = self._forward_impl(
input=input_parallel,
weight=self.weight,
bias=None,
gradient_accumulation_fusion=self.gradient_accumulation_fusion,
async_grad_allreduce=False,
sequence_parallel_enabled=False,
)
# All-reduce across all the partitions.
output_ = reduce_from_tensor_model_parallel_region(output_parallel)
if self.sequence_parallel_enabled:
output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
else:
output_ = reduce_from_tensor_model_parallel_region(output_parallel)
if not self.skip_bias_add:
output = output_ + self.bias if self.bias is not None else output_
output_bias = None
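# Row-parallel math relied on above (illustrative, two pretend ranks): with X split along its
# last dimension as [X_1, X_2] and A split along its first dimension as [A_1; A_2] (each rank
# stores A_i transposed as `self.weight`), the partial products satisfy
#
#     X @ A == X_1 @ A_1 + X_2 @ A_2
#
# so an all-reduce of `output_parallel` (or a reduce-scatter along the sequence dimension when
# `sequence_parallel_enabled` is True) recovers Y = XA before the bias is added.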
......
# coding=utf-8
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021-22, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -20,7 +20,7 @@ from apex.transformer.parallel_state import get_tensor_model_parallel_rank
from apex.transformer.tensor_parallel.utils import split_tensor_along_last_dim
def _reduce(input_):
def _reduce(input_: torch.Tensor) -> torch.Tensor:
"""All-reduce the input tensor across model parallel group."""
# Bypass the function if we are using only 1 GPU.
......@@ -33,7 +33,7 @@ def _reduce(input_):
return input_
def _split(input_):
def _split_along_last_dim(input_: torch.Tensor) -> torch.Tensor:
"""Split the tensor along its last dimension and keep the
corresponding slice."""
......@@ -52,8 +52,24 @@ def _split(input_):
return output
def _gather(input_):
"""Gather tensors and concatinate along the last dimension."""
def _split_along_first_dim(input_: torch.Tensor) -> torch.Tensor:
"""Split the tensor along its first dimension and keep the corresponding slice."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU for tensor model parallel.
if world_size == 1:
return input_
# Split along first dimension.
dim_size = input_.size(0)
assert dim_size % world_size == 0
local_dim_size = dim_size // world_size
dim_offset = get_tensor_model_parallel_rank() * local_dim_size
output = input_[dim_offset:dim_offset + local_dim_size].contiguous()
return output
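# Standalone sketch (pretend world size of 2, plain tensors, no collectives): the helper above
# keeps one contiguous [dim_size / world_size] slice per rank along dim 0, which is how the
# sequence dimension is sharded. Nothing in this module calls this helper.
def _sketch_split_along_first_dim() -> None:
    x = torch.arange(8 * 3).reshape(8, 3)
    world_size = 2
    local_dim_size = x.size(0) // world_size
    shards = [x[r * local_dim_size:(r + 1) * local_dim_size] for r in range(world_size)]
    assert torch.equal(torch.cat(shards, dim=0), x)  # gathering along dim 0 restores the input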
def _gather_along_last_dim(input_: torch.Tensor) -> torch.Tensor:
"""Gather tensors and concatenate along the last dimension."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
......@@ -66,7 +82,9 @@ def _gather(input_):
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank] = input_
torch.distributed.all_gather(tensor_list, input_, group=get_tensor_model_parallel_group())
torch.distributed.all_gather(
tensor_list, input_, group=get_tensor_model_parallel_group()
)
# Note: torch.cat already creates a contiguous tensor.
output = torch.cat(tensor_list, dim=last_dim).contiguous()
......@@ -74,9 +92,49 @@ def _gather(input_):
return output
def _gather_along_first_dim(input_: torch.Tensor) -> torch.Tensor:
"""Gather tensors and concatenate along the first dimension."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
shape = list(input_.shape)
shape[0] *= world_size
output = torch.empty(shape, dtype=input_.dtype, device=torch.cuda.current_device())
torch.distributed._all_gather_base(
output,
input_.contiguous(),
group=get_tensor_model_parallel_group()
)
return output
def _reduce_scatter_along_first_dim(input_: torch.Tensor) -> torch.Tensor:
"""Reduce-scatter the input tensor across model parallel group."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
shape = list(input_.shape)
assert shape[0] % world_size == 0
shape[0] //= world_size
output = torch.empty(shape, dtype=input_.dtype, device=torch.cuda.current_device())
torch.distributed._reduce_scatter_base(
output,
input_.contiguous(),
group=get_tensor_model_parallel_group()
)
return output
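# Standalone sketch (pretend world size of 2, plain tensors, no collectives): a reduce-scatter
# along dim 0 is an element-wise sum over ranks followed by the first-dim split above, which is
# the semantics the sequence parallel region depends on. Nothing in this module calls this helper.
def _sketch_reduce_scatter_along_first_dim() -> None:
    world_size = 2
    per_rank_inputs = [torch.randn(4, 3) for _ in range(world_size)]  # dim 0 divisible by world size
    reduced = torch.stack(per_rank_inputs).sum(dim=0)                 # the "reduce" step
    shards = torch.chunk(reduced, world_size, dim=0)                  # the "scatter" step
    assert all(shard.shape == (2, 3) for shard in shards)             # each rank keeps a [4/2, 3] shard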
class _CopyToModelParallelRegion(torch.autograd.Function):
"""Pass the input to the model parallel region."""
"""Pass the input to the tensor model parallel region."""
# FIXME(mkozuki): Definition of static symbolic methods don't look correct according to
# https://pytorch.org/docs/stable/onnx.html#static-symbolic-method
@staticmethod
def symbolic(graph, input_):
return input_
......@@ -91,8 +149,10 @@ class _CopyToModelParallelRegion(torch.autograd.Function):
class _ReduceFromModelParallelRegion(torch.autograd.Function):
"""All-reduce the input from the model parallel region."""
"""All-reduce the input from the tensor model parallel region."""
# FIXME(mkozuki): Definition of static symbolic methods don't look correct according to
# https://pytorch.org/docs/stable/onnx.html#static-symbolic-method
@staticmethod
def symbolic(graph, input_):
return _reduce(input_)
......@@ -109,33 +169,95 @@ class _ReduceFromModelParallelRegion(torch.autograd.Function):
class _ScatterToModelParallelRegion(torch.autograd.Function):
"""Split the input and keep only the corresponding chuck to the rank."""
# FIXME(mkozuki): Definition of static symbolic methods don't look correct according to
# https://pytorch.org/docs/stable/onnx.html#static-symbolic-method
@staticmethod
def symbolic(graph, input_):
return _split(input_)
return _split_along_last_dim(input_)
@staticmethod
def forward(ctx, input_):
return _split(input_)
return _split_along_last_dim(input_)
@staticmethod
def backward(ctx, grad_output):
return _gather(grad_output)
return _gather_along_last_dim(grad_output)
class _GatherFromModelParallelRegion(torch.autograd.Function):
"""Gather the input from model parallel region and concatinate."""
"""Gather the input from tensor model parallel region and concatenate."""
# FIXME(mkozuki): Definition of static symbolic methods don't look correct according to
# https://pytorch.org/docs/stable/onnx.html#static-symbolic-method
@staticmethod
def symbolic(graph, input_):
return _gather(input_)
return _gather_along_last_dim(input_)
@staticmethod
def forward(ctx, input_):
return _gather(input_)
return _gather_along_last_dim(input_)
@staticmethod
def backward(ctx, grad_output):
return _split(grad_output)
return _split_along_last_dim(grad_output)
class _ScatterToSequenceParallelRegion(torch.autograd.Function):
"""Split the input and keep only the corresponding chunk to the rank."""
# FIXME(mkozuki): Definition of static symbolic methods don't look correct according to
# https://pytorch.org/docs/stable/onnx.html#static-symbolic-method
@staticmethod
def symbolic(graph, input_):
return _split_along_first_dim(input_)
@staticmethod
def forward(ctx, input_):
return _split_along_first_dim(input_)
@staticmethod
def backward(ctx, grad_output):
return _gather_along_first_dim(grad_output)
class _GatherFromSequenceParallelRegion(torch.autograd.Function):
"""Gather the input from sequence parallel region and concatenate."""
# FIXME(mkozuki): Definition of static symbolic methods don't look correct according to
# https://pytorch.org/docs/stable/onnx.html#static-symbolic-method
@staticmethod
def symbolic(graph, input_, to_model_parallel: bool = True):
return _gather_along_first_dim(input_)
@staticmethod
def forward(ctx, input_, to_model_parallel: bool = True):
ctx.to_model_parallel = to_model_parallel
return _gather_along_first_dim(input_)
@staticmethod
def backward(ctx, grad_output):
if ctx.to_model_parallel:
return _reduce_scatter_along_first_dim(grad_output), None
else:
return _split_along_first_dim(grad_output), None
class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function):
"""Reduce scatter the input from the sequence parallel region and concatenate."""
# FIXME(mkozuki): Definition of static symbolic methods don't look correct according to
# https://pytorch.org/docs/stable/onnx.html#static-symbolic-method
@staticmethod
def symbolic(graph, input_):
return _reduce_scatter_along_first_dim(input_)
@staticmethod
def forward(ctx, input_):
return _reduce_scatter_along_first_dim(input_)
@staticmethod
def backward(ctx, grad_output):
return _gather_along_first_dim(grad_output)
# -----------------
......@@ -143,17 +265,40 @@ class _GatherFromModelParallelRegion(torch.autograd.Function):
# -----------------
def copy_to_tensor_model_parallel_region(input_):
def copy_to_tensor_model_parallel_region(input_: torch.Tensor) -> torch.Tensor:
return _CopyToModelParallelRegion.apply(input_)
def reduce_from_tensor_model_parallel_region(input_):
def reduce_from_tensor_model_parallel_region(input_: torch.Tensor) -> torch.Tensor:
return _ReduceFromModelParallelRegion.apply(input_)
def scatter_to_tensor_model_parallel_region(input_):
def scatter_to_tensor_model_parallel_region(input_: torch.Tensor) -> torch.Tensor:
return _ScatterToModelParallelRegion.apply(input_)
def gather_from_tensor_model_parallel_region(input_):
def gather_from_tensor_model_parallel_region(input_: torch.Tensor) -> torch.Tensor:
return _GatherFromModelParallelRegion.apply(input_)
def scatter_to_sequence_parallel_region(input_: torch.Tensor) -> torch.Tensor:
return _ScatterToSequenceParallelRegion.apply(input_)
def gather_from_sequence_parallel_region(input_: torch.Tensor, to_model_parallel: bool = True) -> torch.Tensor:
return _GatherFromSequenceParallelRegion.apply(input_, to_model_parallel)
def reduce_scatter_to_sequence_parallel_region(input_: torch.Tensor) -> torch.Tensor:
return _ReduceScatterToSequenceParallelRegion.apply(input_)
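# Forward/backward collective pairs implemented by the helpers above (descriptive only):
# - copy_to_tensor_model_parallel_region:        identity forward      / all-reduce backward
# - reduce_from_tensor_model_parallel_region:    all-reduce forward    / identity backward
# - scatter_to_tensor_model_parallel_region:     last-dim split        / last-dim all-gather
# - gather_from_tensor_model_parallel_region:    last-dim all-gather   / last-dim split
# - scatter_to_sequence_parallel_region:         first-dim split       / first-dim all-gather
# - gather_from_sequence_parallel_region:        first-dim all-gather  / first-dim reduce-scatter
#                                                (or first-dim split when to_model_parallel=False)
# - reduce_scatter_to_sequence_parallel_region:  first-dim reduce-scatter / first-dim all-gather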
__all__ = [
"copy_to_tensor_model_parallel_region",
"reduce_from_tensor_model_parallel_region",
"scatter_to_tensor_model_parallel_region",
"gather_from_tensor_model_parallel_region",
"scatter_to_sequence_parallel_region",
"gather_from_sequence_parallel_region",
"reduce_scatter_to_sequence_parallel_region",
]
# coding=utf-8
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,6 +13,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO(mkozuki): Remove this file as Megatron-LM seems to have done so.
import torch
......@@ -49,13 +52,20 @@ class MemoryBuffer:
element_size = torch.tensor([], dtype=dtype).element_size()
print(
"> building the {} memory buffer with {} num elements "
"and {} dtype ({:.1f} MB)...".format(name, numel, dtype, numel * element_size / 1024 / 1024),
"and {} dtype ({:.1f} MB)...".format(
name, numel, dtype, numel * element_size / 1024 / 1024
),
flush=True,
)
self.name = name
self.numel = numel
self.dtype = dtype
self.data = torch.empty(self.numel, dtype=self.dtype, device=torch.cuda.current_device(), requires_grad=False)
self.data = torch.empty(
self.numel,
dtype=self.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
# Index tracking the start of the free memory.
self._start = 0
......@@ -81,13 +91,17 @@ class MemoryBuffer:
def add(self, tensor):
"""Allocate a chunk of memory from the buffer to tensor and copy
the values."""
assert tensor.dtype == self.dtype, "Input tensor type {} different from buffer type {}".format(
assert (
tensor.dtype == self.dtype
), "Input tensor type {} different from buffer type {}".format(
tensor.dtype, self.dtype
)
# Number of elements of the input tensor.
tensor_numel = torch.numel(tensor)
new_start = self._start + tensor_numel
assert new_start <= self.numel, "Not enough memory left in the buffer ({} > {})".format(
assert (
new_start <= self.numel
), "Not enough memory left in the buffer ({} > {})".format(
tensor_numel, self.numel - self._start
)
# New tensor is a view into the memory.
......@@ -124,7 +138,8 @@ class RingMemBuffer:
def __init__(self, name, num_buffers, numel, dtype, track_usage):
self.num_buffers = num_buffers
self.buffers = [
allocate_mem_buff(name + " {}".format(i), numel, dtype, track_usage) for i in range(num_buffers)
allocate_mem_buff(name + " {}".format(i), numel, dtype, track_usage)
for i in range(num_buffers)
]
self._index = -1
......
# coding=utf-8
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021-22, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -13,9 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO (mkozuki): Audit this file.
# I don't think some functions strongly relate to `random` in tensor_parallel.
# Rather, some functions are mainly for gradient checkpointing (torch.utils.checkpoint).
# NOTE(mkozuki): This file is based on megatron-lm/mpu/random.py with some differences:
# - Not using "viewless" tensor:
# - _kernel_make_viewless_tensor
# - MakeViewlessTensor
# - make_viewless_tensor
# - assert_viewless_tensor
# - safely_set_viewless_tensor_data
# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch
......@@ -35,13 +39,12 @@ from apex.transformer.utils import gather_split_1d_tensor
# Default name for the model parallel rng tracker.
_MODEL_PARALLEL_RNG_TRACKER_NAME = "model-parallel-rng"
# TODO(mkozuki): Remove `_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER` as megatron-lm doesn't seem to use.
# Whether apply model parallelism to checkpointed hidden states.
_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = None
# TODO (mkozuki): Consider the possibility of removing `tensor_model_parallel_size`,
# `get_tensor_model_parallel_world_size()` might be alternative.
# TODO(mkozuki): Remove `init_checkpointed_activations_memory_buffer` as megatron-lm doesn't seem to use.
def init_checkpointed_activations_memory_buffer(
micro_batch_size,
max_position_embeddings,
......@@ -53,8 +56,15 @@ def init_checkpointed_activations_memory_buffer(
):
"""Initializ the memory buffer for the checkpointed activations."""
per_layer = micro_batch_size * max_position_embeddings * hidden_size // tensor_model_parallel_size
assert num_layers % checkpoint_num_layers == 0, "number of layers is not divisible by checkpoint-num-layers"
per_layer = (
micro_batch_size
* max_position_embeddings
* hidden_size
// tensor_model_parallel_size
)
assert (
num_layers % checkpoint_num_layers == 0
), "number of layers is not divisible by checkpoint-num-layers"
num_checkpointer_layers = num_layers // checkpoint_num_layers
numel = per_layer * num_checkpointer_layers
dtype = torch.half
......@@ -70,6 +80,7 @@ def init_checkpointed_activations_memory_buffer(
)
# TODO(mkozuki): Remove `reset_checkpointed_activations_memory_buffer` as megatron-lm doesn't seem to use.
def reset_checkpointed_activations_memory_buffer():
"""Reset the memory used for checkpointing."""
if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
......@@ -79,7 +90,7 @@ def reset_checkpointed_activations_memory_buffer():
def _set_cuda_rng_state(new_state, device=-1):
"""Sets the random number generator state of the current GPU.
Argumentss:
Arguments:
new_state (torch.ByteTensor): The desired state
This function is adapted from PyTorch repo (torch.cuda.set_rng_state)
with a single change: the input state is not cloned. Cloning caused
......@@ -217,7 +228,9 @@ def model_parallel_cuda_manual_seed(seed):
# Set the default state.
torch.cuda.manual_seed(data_parallel_seed)
# and model parallel state.
_CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, tensor_model_parallel_seed)
_CUDA_RNG_STATE_TRACKER.add(
_MODEL_PARALLEL_RNG_TRACKER_NAME, tensor_model_parallel_seed
)
# TODO (mkozuki): Move the below gradient checkpoint related features to another (new) file.
......@@ -230,8 +243,9 @@ class CheckpointFunction(torch.autograd.Function):
"""
@staticmethod
def forward(ctx, run_function, *args):
def forward(ctx, run_function, distribute_saved_activations, *args):
ctx.run_function = run_function
ctx.distribute_saved_activations = distribute_saved_activations
# Copy the rng states.
ctx.fwd_cpu_rng_state = torch.get_rng_state()
......@@ -243,10 +257,8 @@ class CheckpointFunction(torch.autograd.Function):
# Divide hidden states across model parallel group and only keep
# the chunk corresponding to the current rank.
if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
ctx.input_0_shape = args[0].data.shape
args[0].data = split_tensor_into_1d_equal_chunks(args[0].data)
args[0].data = _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.add(args[0].data)
if ctx.distribute_saved_activations:
ctx.input_0_shape = args[0].shape
# Store everything.
ctx.save_for_backward(*args)
......@@ -255,11 +267,11 @@ class CheckpointFunction(torch.autograd.Function):
@staticmethod
def backward(ctx, *args):
if not torch.autograd._is_checkpoint_valid():
raise RuntimeError("Checkpointing is not compatible with .grad(), " "please use .backward() if possible")
raise RuntimeError(
"Checkpointing is not compatible with .grad(), "
"please use .backward() if possible"
)
inputs = ctx.saved_tensors
if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
inputs[0].data = gather_split_1d_tensor(inputs[0].data)
inputs[0].data = inputs[0].data.view(ctx.input_0_shape)
# Store the current states.
bwd_cpu_rng_state = torch.get_rng_state()
......@@ -284,11 +296,16 @@ class CheckpointFunction(torch.autograd.Function):
if isinstance(outputs, torch.Tensor):
outputs = (outputs,)
torch.autograd.backward(outputs, args)
grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs)
return (None,) + grads
grads = tuple(
inp.grad if isinstance(inp, torch.Tensor) else inp
for inp in detached_inputs
)
return (None, None) + grads
def checkpoint(function, *args):
# NOTE(mkozuki): It doesn't look like `distribute_saved_activations` is used in apex.transformer
# but I added this change to reduce the superficial difference from Megatron-LM.
def checkpoint(function, distribute_saved_activations, *args):
"""Checkpoint a model or part of the model.
This has been directly copied from torch.utils.checkpoint."""
return CheckpointFunction.apply(function, *args)
return CheckpointFunction.apply(function, distribute_saved_activations, *args)
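# Standalone usage sketch (assumes a CUDA device, since the RNG bookkeeping above saves CUDA RNG
# state; `run` and the shapes below are hypothetical). The second positional argument is the new
# `distribute_saved_activations` flag; passing False keeps plain recompute-in-backward behavior.
def _sketch_checkpoint_usage() -> None:
    weight = torch.randn(4, 4, device="cuda", requires_grad=True)

    def run(x):
        return torch.relu(x @ weight)

    x = torch.randn(2, 4, device="cuda", requires_grad=True)
    y = checkpoint(run, False, x)  # activations of `run` are recomputed during backward
    y.sum().backward()             # populates x.grad and weight.grad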
......@@ -12,12 +12,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Sequence
import torch
from apex.transformer.utils import divide
def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
def split_tensor_along_last_dim(
tensor: torch.Tensor,
num_partitions: int,
contiguous_split_chunks: bool = False,
) -> List[torch.Tensor]:
"""Split a tensor along its last dimension.
Arguments:
tensor: input tensor.
......@@ -43,12 +49,16 @@ class VocabUtility:
partition: Note that indices are in [first, last)."""
@staticmethod
def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size):
def vocab_range_from_per_partition_vocab_size(
per_partition_vocab_size: int, rank, world_size: int
) -> Sequence[int]:
index_f = rank * per_partition_vocab_size
index_l = index_f + per_partition_vocab_size
return index_f, index_l
@staticmethod
def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int, world_size: int) -> Sequence[int]:
per_partition_vocab_size = divide(global_vocab_size, world_size)
return VocabUtility.vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size)
return VocabUtility.vocab_range_from_per_partition_vocab_size(
per_partition_vocab_size, rank, world_size
)
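# Worked example (hypothetical sizes): a global vocabulary of 50304 split across a tensor model
# parallel world size of 8 gives 6288 ids per rank, so
#
#     VocabUtility.vocab_range_from_global_vocab_size(50304, 3, 8)   # -> (18864, 25152)
#
# i.e. rank 3 owns the half-open id range [18864, 25152), which is exactly the range
# VocabParallelEmbedding uses to mask out-of-partition token ids before its all-reduce.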
......@@ -39,9 +39,13 @@ def parse_args(extra_args_provider=None, defaults={},
parser = _add_data_args(parser)
parser = _add_autoresume_args(parser)
parser = _add_biencoder_args(parser)
parser = _add_vit_args(parser)
parser = _add_vision_args(parser)
parser = _add_logging_args(parser)
# NOTE(mkozuki): This option is added to investigate the potential of `torch.autograd.graph.save_on_cpu()`.
# ref: https://pytorch.org/docs/stable/autograd.html#torch.autograd.graph.save_on_cpu.
parser.add_argument('--cpu-offload', action='store_true', default=False, help='Turns on CPU offloading')
# Custom arguments.
if extra_args_provider is not None:
parser = extra_args_provider(parser)
......@@ -65,6 +69,11 @@ def parse_args(extra_args_provider=None, defaults={},
args.pipeline_model_parallel_size = min(
args.pipeline_model_parallel_size,
(args.world_size // args.tensor_model_parallel_size))
args.transformer_pipeline_model_parallel_size = (
args.pipeline_model_parallel_size - 1
if args.standalone_embedding_stage else
args.pipeline_model_parallel_size
)
# Checks.
model_parallel_size = args.pipeline_model_parallel_size * \
args.tensor_model_parallel_size
......@@ -98,13 +107,18 @@ def parse_args(extra_args_provider=None, defaults={},
'longer valid, use --tensor-model-parallel-size instead'
del args.model_parallel_size
if args.checkpoint_activations:
args.activations_checkpoint_method = 'uniform'
args.recompute_granularity = 'full'
args.recompute_method = 'uniform'
if args.rank == 0:
print('--checkpoint-activations is no longer valid, '
'use --activation-checkpoint-method instead. '
'Defaulting to activation-checkpoint-method=uniform.')
'use --recompute-granularity and --recompute-method instead. '
'Defaulting to recompute-granularity=full and recompute-method=uniform.')
del args.checkpoint_activations
if args.recompute_activations:
args.recompute_granularity = 'selective'
del args.recompute_activations
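# Illustrative command-line mapping for the deprecated flags handled just above:
#
#     --checkpoint-activations   ->  equivalent to --recompute-granularity full --recompute-method uniform
#     --recompute-activations    ->  equivalent to --recompute-granularity selective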
# Set input defaults.
for key in defaults:
# For default to be valid, it should not be provided in the
......@@ -166,6 +180,14 @@ def parse_args(extra_args_provider=None, defaults={},
if args.accumulate_allreduce_grads_in_fp32:
assert args.DDP_impl == 'local'
assert args.use_contiguous_buffers_in_local_ddp
else:
if args.gradient_accumulation_fusion:
args.gradient_accumulation_fusion = False
if args.rank == 0:
print('Gradient accumulation fusion to linear layer weight '
'gradient computation is supported only with fp32 '
'gradient accumulation. Setting gradient_accumulation_fusion '
'to False', flush=True)
# For torch DDP, we do not use contiguous buffer
if args.DDP_impl == 'torch':
......@@ -244,17 +266,51 @@ def parse_args(extra_args_provider=None, defaults={},
if args.fp32_residual_connection:
assert args.fp16 or args.bf16, \
'residual connection in fp32 only supported when using fp16 or bf16.'
# Activation checkpointing.
if args.distribute_checkpointed_activations:
if args.weight_decay_incr_style == 'constant':
assert args.start_weight_decay is None
assert args.end_weight_decay is None
args.start_weight_decay = args.weight_decay
args.end_weight_decay = args.weight_decay
else:
assert args.start_weight_decay is not None
assert args.end_weight_decay is not None
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
# Persistent fused layer norm.
if TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 11):
args.no_persist_layer_norm = True
if args.rank == 0:
print('Persistent fused layer norm kernel is supported from '
'pytorch v1.11 (nvidia pytorch container paired with v1.11). '
'Defaulting to no_persist_layer_norm=True')
# Activation recomputing.
if args.distribute_saved_activations:
assert args.tensor_model_parallel_size > 1, 'can distribute ' \
'checkpointed activations only across tensor model ' \
'recomputed activations only across tensor model ' \
'parallel groups'
assert args.activations_checkpoint_method is not None, \
'for distribute-checkpointed-activations to work you '\
'need to use a activation-checkpoint method '
assert args.num_layers_per_virtual_pipeline_stage is None, \
'currently distrobuted checkpoint activations only supported for ' \
'nointerleaved pipeline parallelism'
assert args.recompute_granularity == 'full', \
'distributed recompute activations is only '\
'applicable to full recompute granularity'
assert args.recompute_method is not None, \
'for distributed recompute activations to work you '\
'need to use a recompute method '
assert TORCH_MAJOR >= 1 and TORCH_MINOR >= 10, \
'distributed recompute activations are supported for pytorch ' \
'v1.10 and above (Nvidia Pytorch container >= 21.07). Current ' \
'pytorch version is v%s.%s.' % (TORCH_MAJOR, TORCH_MINOR)
if args.recompute_granularity == 'selective':
assert args.recompute_method is None, \
'recompute method is not yet supported for ' \
'selective recomputing granularity'
# disable async_tensor_model_parallel_allreduce when
# model parallel memory optimization is enabled
if args.sequence_parallel:
args.async_tensor_model_parallel_allreduce = False
_print_args(args)
return args
......@@ -279,6 +335,18 @@ def _check_arg_is_not_none(args, arg):
assert getattr(args, arg) is not None, '{} argument is None'.format(arg)
def _add_inference_args(parser):
group = parser.add_argument_group(title='inference')
group.add_argument('--inference-batch-times-seqlen-threshold',
type=int, default=512,
help='During inference, if batch-size times '
'sequence-length is smaller than this threshold '
'then we will not use pipelining, otherwise we will.')
return parser
def _add_network_size_args(parser):
group = parser.add_argument_group(title='network size')
......@@ -318,6 +386,8 @@ def _add_network_size_args(parser):
group.add_argument('--bert-no-binary-head', action='store_false',
help='Disable BERT binary head.',
dest='bert_binary_head')
group.add_argument('--num-experts', type=int, default=None,
help='Number of Experts in Switch Transformer (None means no Switch)')
return parser
......@@ -354,6 +424,9 @@ def _add_logging_args(parser):
group.add_argument('--log-memory-to-tensorboard',
action='store_true',
help='Enable memory logging to tensorboard.')
group.add_argument('--log-world-size-to-tensorboard',
action='store_true',
help='Enable world size logging to tensorboard.')
return parser
......@@ -367,6 +440,13 @@ def _add_regularization_args(parser):
help='Dropout probability for hidden state transformer.')
group.add_argument('--weight-decay', type=float, default=0.01,
help='Weight decay coefficient for L2 regularization.')
group.add_argument('--start-weight-decay', type=float,
help='Initial weight decay coefficient for L2 regularization.')
group.add_argument('--end-weight-decay', type=float,
help='End of run weight decay coefficient for L2 regularization.')
group.add_argument('--weight-decay-incr-style', type=str, default='constant',
choices=['constant', 'linear', 'cosine'],
help='Weight decay increment function.')
group.add_argument('--clip-grad', type=float, default=1.0,
help='Gradient clipping based on global L2 norm.')
group.add_argument('--adam-beta1', type=float, default=0.9,
......@@ -413,27 +493,40 @@ def _add_training_args(parser):
' (1024 - 16) / 8 = 126 intervals will increase'
'the batch size linearly to 1024. In each interval'
'we will use approximately 300000 / 126 = 2380 samples.')
group.add_argument('--checkpoint-activations', action='store_true',
help='Checkpoint activation to allow for training '
group.add_argument('--recompute-activations', action='store_true',
help='recompute activation to allow for training '
'with larger models, sequences, and batch sizes.')
group.add_argument('--distribute-checkpointed-activations',
group.add_argument('--recompute-granularity', type=str, default=None,
choices=['full', 'selective'],
help='Checkpoint activations to allow for training '
'with larger models, sequences, and batch sizes. '
'It is supported at two granularities 1) full: '
'whole transformer layer is recomputed, '
'2) selective: core attention part of the transformer '
'layer is recomputed.')
group.add_argument('--distribute-saved-activations',
action='store_true',
help='If set, distribute checkpointed activations '
help='If set, distribute recomputed activations '
'across model parallel group.')
group.add_argument('--activations-checkpoint-method', type=str, default=None,
group.add_argument('--recompute-method', type=str, default=None,
choices=['uniform', 'block'],
help='1) uniform: uniformly divide the total number of '
'Transformer layers and checkpoint the input activation of '
'each divided chunk, '
'2) checkpoint the input activations of only a set number of '
'Transformer layers and recompute the input activation of '
'each divided chunk at specified granularity, '
'2) recompute the input activations of only a set number of '
'individual Transformer layers per pipeline stage and do the '
'rest without any checkpointing'
'default) do not apply activations checkpoint to any layers')
group.add_argument('--activations-checkpoint-num-layers', type=int, default=1,
'rest without any recomputing at specified granularity, '
'default) do not apply activations recompute to any layers')
group.add_argument('--recompute-num-layers', type=int, default=1,
help='1) uniform: the number of Transformer layers in each '
'uniformly divided checkpoint unit, '
'uniformly divided recompute unit, '
'2) block: the number of individual Transformer layers '
'to checkpoint within each pipeline stage.')
'to recompute within each pipeline stage.')
# deprecated
group.add_argument('--checkpoint-activations', action='store_true',
help='Checkpoint activation to allow for training '
'with larger models, sequences, and batch sizes.')
group.add_argument('--train-iters', type=int, default=None,
help='Total number of iterations to train over all '
'training runs. Note that either train-iters or '
......@@ -472,7 +565,20 @@ def _add_training_args(parser):
action='store_true',
help='Disable asynchronous execution of '
'tensor-model-parallel all-reduce with weight '
'gradient compuation of a column-linear layer.')
'gradient computation of a column-linear layer.',
dest='async_tensor_model_parallel_allreduce')
group.add_argument('--no-persist-layer-norm', action='store_true',
help='Disable using persistent fused layer norm kernel. '
'This kernel supports only a set of hidden sizes. Please '
'check persist_ln_hidden_sizes if your hidden '
'size is supported.')
group.add_argument('--sequence-parallel', action='store_true',
help='Enable sequence parallel optimization.')
group.add_argument('--no-gradient-accumulation-fusion',
action='store_false',
help='Disable fusing gradient accumulation to weight '
'gradient computation of linear layers',
dest='gradient_accumulation_fusion')
return parser
......@@ -645,6 +751,11 @@ def _add_distributed_args(parser):
help='Call torch.cuda.empty_cache() each iteration '
'(training and eval), to reduce fragmentation.'
'0=off, 1=moderate, 2=aggressive.')
group.add_argument('--standalone-embedding-stage', action='store_true',
default=False, help='If set, *input* embedding layer '
'is placed on its own pipeline stage, without any '
'transformer layers. (For T5, this flag currently only '
'affects the encoder embedding.)')
return parser
......@@ -791,16 +902,70 @@ def _add_biencoder_args(parser):
return parser
def _add_vit_args(parser):
group = parser.add_argument_group(title="vit")
def _add_vision_args(parser):
group = parser.add_argument_group(title="vision")
# general vision arguments
group.add_argument('--num-classes', type=int, default=1000,
help='num of classes in vision classification task')
group.add_argument('--img-dim', type=int, default=224,
help='Image size for vision classification task')
group.add_argument('--img-h', type=int, default=224,
help='Image height for vision classification task')
group.add_argument('--img-w', type=int, default=224,
help='Image width for vision classification task')
group.add_argument('--num-channels', type=int, default=3,
help='Number of channels in input image data')
group.add_argument('--patch-dim', type=int, default=16,
help='patch dimension used in vit')
help='patch dimension')
group.add_argument('--classes-fraction', type=float, default=1.0,
help='training with fraction of classes.')
group.add_argument('--data-per-class-fraction', type=float, default=1.0,
help='training with fraction of data per class.')
group.add_argument('--no-data-sharding', action='store_false',
help='Disable data sharding.',
dest='data_sharding')
group.add_argument('--head-lr-mult', type=float, default=1.0,
help='learning rate multiplier for head during finetuning')
# pretraining type and backbone selection
group.add_argument('--vision-pretraining', action='store_true',
help='flag to indicate vision pretraining')
group.add_argument('--vision-pretraining-type', type=str, default='classify',
choices=['classify', 'inpaint', 'dino'],
help='pretraining objectives')
group.add_argument('--vision-backbone-type', type=str, default='vit',
choices=['vit', 'mit', 'swin'],
help='backbone type')
group.add_argument('--swin-backbone-type', type=str, default='tiny',
choices=['tiny', 'base', 'h3'],
help='swin backbone type')
# inpainting arguments
group.add_argument('--mask-type', type=str, default='random',
choices=['random', 'row'],
help='mask types')
group.add_argument('--mask-factor', type=float, default=1.0,
help='mask size scaling parameter')
# dino arguments
group.add_argument('--iter-per-epoch', type=int, default=1250,
help='iterations per epoch')
group.add_argument('--dino-local-img-size', type=int, default=96,
help='Image size of DINO local crops')
group.add_argument('--dino-local-crops-number', type=int, default=10,
help='Number of local crops')
group.add_argument('--dino-head-hidden-size', type=int, default=2048,
help='Hidden dimension size in dino head')
group.add_argument('--dino-bottleneck-size', type=int, default=256,
help='Bottleneck dimension in dino head')
group.add_argument('--dino-freeze-last-layer', type=float, default=1,
help='Freezing last layer weights')
group.add_argument('--dino-norm-last-layer', action='store_true',
help='Disable Norm in last layer.')
group.add_argument('--dino-warmup-teacher-temp', type=float, default=0.04,
help='warmup teacher temperature')
group.add_argument('--dino-teacher-temp', type=float, default=0.07,
help='teacher temperature')
group.add_argument('--dino-warmup-teacher-temp-epochs', type=int, default=30,
help='warmup teacher temperature epochs')
return parser
......@@ -12,15 +12,28 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
import datetime
import os
import random
from typing import Optional, Union, List
from typing import Optional, Union, List, Tuple, Callable, Dict
import numpy
import torch
import torch.nn as nn
from apex import transformer
from apex.transformer.tensor_parallel import(
ColumnParallelLinear,
RowParallelLinear,
scatter_to_sequence_parallel_region,
)
from apex.transformer.pipeline_parallel.utils import (
average_losses_across_data_parallel_group,
)
from apex.transformer.pipeline_parallel.schedules.common import (
Batch,
)
from apex.transformer.testing import global_vars
......@@ -29,7 +42,6 @@ TEST_SUCCESS_MESSAGE = ">> passed the test :-)"
# note (mkozuki): `pre_process` and `post_process` are placeholders until the interleaving schedule test comes.
class MyLayer(nn.Module):
def __init__(self, hidden_size: int, pre_process: bool, post_process: bool):
super().__init__()
self.pre_process = pre_process
......@@ -39,17 +51,28 @@ class MyLayer(nn.Module):
def forward(self, x):
return self.layer(x)
class MyModel(nn.Module):
def __init__(self, hidden_size: int, pre_process: bool = False, post_process: bool = False) -> None:
class MyModel(nn.Module):
def __init__(
self,
hidden_size: int, pre_process: bool = False, post_process: bool = False,
*,
add_encoder: bool = False, add_decoder: bool = False,
) -> None:
super().__init__()
self.pre_process = pre_process
self.post_process = post_process
self.layer = MyLayer(hidden_size=hidden_size, pre_process=pre_process, post_process=post_process)
self.layer = MyLayer(
hidden_size=hidden_size, pre_process=pre_process, post_process=post_process
)
self.input_tensor = None
def set_input_tensor(self, input_tensor: Union[torch.Tensor, List[torch.Tensor]]) -> None:
self.input_tensor = input_tensor
def set_input_tensor(
self, input_tensor: Union[torch.Tensor, List[torch.Tensor]]
) -> None:
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
self.input_tensor = input_tensor[0]
def forward(self, x: Optional[torch.Tensor]) -> torch.Tensor:
if self.input_tensor is None:
......@@ -57,8 +80,154 @@ class MyModel(nn.Module):
return self.layer(self.input_tensor)
def model_provider_func(hidden_size, pre_process, post_process) -> MyModel:
return MyModel(hidden_size, pre_process, post_process)
class ToyParallelMLP(nn.Module):
def __init__(
self,
hidden_size: int, pre_process: bool = False, post_process: bool = False,
*,
sequence_parallel_enabled: bool = False,
# TODO(mkozuki): Support these two?
add_encoder: bool = False, add_decoder: bool = False,
) -> None:
super().__init__()
self.pre_process = pre_process
self.post_process = post_process
self.sequence_parallel_enabled = sequence_parallel_enabled
ffn_hidden_size = 4 * hidden_size
self.dense_h_to_4h = ColumnParallelLinear(
hidden_size,
ffn_hidden_size,
gather_output=False,
# init_method=init_method,
skip_bias_add=True,
# use_cpu_initialization=use_cpu_initialization,
bias=True,
sequence_parallel_enabled=sequence_parallel_enabled,
no_async_tensor_model_parallel_allreduce=True,
)
self.dense_4h_to_h = RowParallelLinear(
ffn_hidden_size,
hidden_size,
input_is_parallel=True,
# init_method=output_layer_init_method,
skip_bias_add=False,
# use_cpu_initialization=use_cpu_initialization,
bias=True,
sequence_parallel_enabled=sequence_parallel_enabled,
)
self.activation_func = torch.nn.GELU()
def set_input_tensor(
self,
input_tensor: Union[torch.Tensor, List[torch.Tensor]],
) -> None:
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
self.input_tensor = input_tensor[0]
def forward(
self,
x: Optional[torch.Tensor],
) -> torch.Tensor:
"""Forward of Simplified ParallelMLP.
Args:
x: :obj:`None` if pipeline rank != the first pipeline rank. When :obj:`None`,
`self.input_tensor` is taken care of by `forward_step` defined in
apex/transformer/pipeline_parallel/schedules/common.py
"""
# [s, b, h]
if self.input_tensor is None:
input = x
else:
input = self.input_tensor
intermediate_parallel, bias_parallel = self.dense_h_to_4h(input)
if bias_parallel is not None:
intermediate_parallel += bias_parallel
intermediate_parallel = self.activation_func(intermediate_parallel)
# [s, b, h]
output, output_bias = self.dense_4h_to_h(intermediate_parallel)
return output
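# note: for reference only -- with a tensor-model-parallel size of 1, the module
# above computes the same function as this plain single-GPU MLP.  The class is an
# illustrative sketch and is not used by the tests in this file.
class _ToySerialMLPReference(torch.nn.Module):
    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        ffn_hidden_size = 4 * hidden_size
        self.dense_h_to_4h = torch.nn.Linear(hidden_size, ffn_hidden_size)
        self.dense_4h_to_h = torch.nn.Linear(ffn_hidden_size, hidden_size)
        self.activation_func = torch.nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # [s, b, h] -> [s, b, 4h] -> GELU -> [s, b, h]
        return self.dense_4h_to_h(self.activation_func(self.dense_h_to_4h(x)))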
def model_provider_func(
hidden_size: int,
pre_process: bool,
post_process: bool,
*,
add_encoder: bool = False,
add_decoder: bool = False) -> MyModel:
return MyModel(hidden_size, pre_process, post_process, add_encoder=add_encoder, add_decoder=add_decoder)
def mlp_provider_func(
hidden_size: int,
pre_process: bool,
post_process: bool,
*,
add_encoder: bool = False,
add_decoder: bool = False,
sequence_parallel_enabled: bool = False,
) -> ToyParallelMLP:
return ToyParallelMLP(
hidden_size,
pre_process,
post_process,
add_encoder=add_encoder,
add_decoder=add_decoder,
sequence_parallel_enabled=sequence_parallel_enabled,
)
def process_batch(batch):
if isinstance(batch, list):
x = batch[0]
else:
x = batch
return x
def fwd_step_func(batch, model):
x = process_batch(batch)
y = model(x)
    # note (mkozuki): this loss function is not realistic, but it is enough for now
    # to sanity-check the ported pipeline functions.
def loss_func(x):
loss = torch.sum(x)
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {"avg": averaged_loss}
return y, loss_func
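# note: an illustrative check of the contract above -- the forward-step function
# returns ``(output, loss_func)`` and the schedule applies ``loss_func`` to the
# output on the last pipeline stage.  ``nn.Identity`` is only a stand-in model;
# calling the returned ``loss_func`` additionally requires an initialized
# data-parallel group, so it is not invoked here.
def _example_fwd_step_contract() -> Tuple[torch.Tensor, Callable]:
    toy_batch = [torch.ones(2, 3)]
    output, loss_func = fwd_step_func(toy_batch, nn.Identity())
    return output, loss_func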
@dataclass(frozen=True)
class ToyParallelMLPFwdBwdStepFunc:
sequence_parallel_enabled: bool
def __call__(
self,
batch: Batch,
model: torch.nn.Module,
) -> Tuple[torch.Tensor, Callable[[torch.Tensor], Tuple[torch.Tensor, Dict[str, torch.Tensor]]]]:
x = batch[0] if isinstance(batch, list) else batch
if isinstance(x, torch.Tensor):
x = x.transpose(0, 1).contiguous()
if self.sequence_parallel_enabled:
x = scatter_to_sequence_parallel_region(x)
y = model(x)
        # note (mkozuki): this loss function is not realistic, but it is enough for now
        # to sanity-check the ported pipeline functions.
def loss_func(x):
loss = torch.sum(x)
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {"avg": averaged_loss}
return y, loss_func
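# note: the callable above transposes the batch to the sequence-first [s, b, h]
# layout expected by the tensor-parallel layers and, when sequence parallelism is
# enabled, scatters the activation along the sequence dimension across the
# tensor-model-parallel group -- e.g. with a group of size 2, an [8, b, h] tensor
# becomes a [4, b, h] shard per rank, much like ``torch.chunk(x, 2, dim=0)``.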
class IdentityLayer(torch.nn.Module):
......@@ -78,22 +247,28 @@ def set_random_seed(seed):
transformer.tensor_parallel.model_parallel_cuda_manual_seed(seed)
def initialize_distributed(backend="nccl"):
"""Initialize torch.distributed."""
# Get local rank in case it is provided.
# parser = argparse.ArgumentParser()
# parser.add_argument('--local_rank', type=int, default=None,
# help='local rank passed from distributed launcher')
# args = parser.parse_args()
if backend not in ("nccl", "ucc"):
        raise RuntimeError(f"Currently only the 'nccl' and 'ucc' backends are supported, but got {backend!r}.")
if backend == "ucc":
import torch_ucc # NOQA
args = global_vars.get_args()
local_rank = args.local_rank
# Get rank and world size.
rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
print(
"> initializing torch.distributed with local rank: {}, "
"rank: {}, world size: {}".format(local_rank, rank, world_size)
)
# Set the device id.
device = rank % torch.cuda.device_count()
......@@ -102,22 +277,21 @@ def initialize_distributed(backend='nccl'):
torch.cuda.set_device(device)
# Call the init process.
init_method = "tcp://"
master_ip = os.getenv("MASTER_ADDR", "localhost")
master_port = os.getenv("MASTER_PORT", "6000")
init_method += master_ip + ":" + master_port
torch.distributed.init_process_group(
backend=backend, world_size=world_size, rank=rank, init_method=init_method,
timeout=datetime.timedelta(seconds=60),
)
def print_separator(message):
torch.distributed.barrier()
filler_len = (78 - len(message)) // 2
filler = "-" * filler_len
string = "\n" + filler + " {} ".format(message) + filler
if torch.distributed.get_rank() == 0:
print(string, flush=True)
torch.distributed.barrier()
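# note: a rough sketch of how the helpers in this file are typically combined by
# the tests in this package (argument parsing and model-parallel state setup are
# elided, so treat the exact call sequence as illustrative only):
#
#   initialize_distributed("nccl")   # uses RANK / WORLD_SIZE / MASTER_ADDR / MASTER_PORT
#   set_random_seed(1234)            # also seeds the model-parallel CUDA RNG
#   print_separator("pipeline parallel test")
#   model = model_provider_func(hidden_size=16, pre_process=True, post_process=True)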