Unverified Commit db92ee13 authored by Jithun Nair, committed by GitHub

Merge pull request #64 from ROCmSoftwarePlatform/IFU-master-2021-12-08

IFU-master-2021-12-08
parents d150afdc 68364b49
from apex.transformer.amp.grad_scaler import GradScaler
__all__ = [
"GradScaler",
]
from collections import defaultdict
import torch
from apex.transformer import parallel_state
class GradScaler(torch.cuda.amp.GradScaler):
"""
Gradient scaler for model-parallel inf checks. Infs in gradients are checked across model-parallel
ranks when (1) executing the optimizer step and (2) updating the gradient scaler.
"""
def __init__(
self, init_scale=2.0 ** 16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, enabled=True
):
super().__init__(
init_scale=init_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
enabled=enabled,
)
def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs):
retval = None
found_inf = torch.cuda.FloatTensor([sum(v.item() for v in optimizer_state["found_inf_per_device"].values())])
# Update across all model parallel instances.
torch.distributed.all_reduce(
found_inf, op=torch.distributed.ReduceOp.MAX, group=parallel_state.get_model_parallel_group()
)
if found_inf.item() == 0:
retval = optimizer.step(*args, **kwargs)
return retval
def update(self, new_scale=None):
"""
Updates the scale factor.
If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
the scale is multiplied by ``growth_factor`` to increase it.
Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
used directly, it's used to fill GradScaler's internal scale tensor. So if
``new_scale`` was a tensor, later in-place changes to that tensor will not further
affect the scale GradScaler uses internally.)
Args:
new_scale (float or :class:`torch.cuda.FloatTensor`, optional, default=None): New scale factor.
.. warning::
:meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
been invoked for all optimizers used this iteration.
"""
if not self._enabled:
return
_scale, _growth_tracker = self._check_scale_growth_tracker("update")
if new_scale is not None:
# Accept a new user-defined scale.
if isinstance(new_scale, float):
self._scale.fill_(new_scale) # type: ignore[union-attr]
else:
reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False."
assert isinstance(new_scale, torch.cuda.FloatTensor), reason # type: ignore[attr-defined]
assert new_scale.numel() == 1, reason
assert new_scale.requires_grad is False, reason
self._scale.copy_(new_scale) # type: ignore[union-attr]
else:
# Consume shared inf/nan data collected from optimizers to update the scale.
# If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
found_infs = [
found_inf.to(device=_scale.device, non_blocking=True)
for state in self._per_optimizer_states.values()
for found_inf in state["found_inf_per_device"].values()
]
assert len(found_infs) > 0, "No inf checks were recorded prior to update."
found_inf_combined = found_infs[0]
# Update across all model parallel instances.
torch.distributed.all_reduce(
found_inf_combined, op=torch.distributed.ReduceOp.MAX, group=parallel_state.get_model_parallel_group()
)
if len(found_infs) > 1:
for i in range(1, len(found_infs)):
found_inf = found_infs[i]
# Update across all model parallel instances.
torch.distributed.all_reduce(
found_inf, op=torch.distributed.ReduceOp.MAX, group=parallel_state.get_model_parallel_group()
)
found_inf_combined += found_inf
torch._amp_update_scale_(
_scale,
_growth_tracker,
found_inf_combined,
self._growth_factor,
self._backoff_factor,
self._growth_interval,
)
# To prepare for next iteration, clear the data collected from optimizers this iteration.
self._per_optimizer_states = defaultdict(torch.cuda.amp.grad_scaler._refresh_per_optimizer_state)
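# Hedged usage sketch for the GradScaler above: it assumes torch.distributed and
# apex's parallel_state are already initialized (the all_reduce inside the scaler
# needs a model-parallel group); the tiny Linear model below is only a placeholder.
import torch
from apex.transformer.amp.grad_scaler import GradScaler

model = torch.nn.Linear(16, 16).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler(init_scale=2.0 ** 16)
for _ in range(3):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = model(torch.randn(8, 16, device="cuda")).sum()
    scaler.scale(loss).backward()
    scaler.step(optimizer)  # skipped on every rank if any model-parallel rank saw inf/nan
    scaler.update()         # growth/backoff decision shared across the model-parallel group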
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
from .fused_softmax import FusedScaleMaskSoftmax
from apex.transformer.functional.fused_softmax import FusedScaleMaskSoftmax
__all__ = [
"FusedScaleMaskSoftmax",
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -15,7 +15,7 @@
import torch
from apex._autocast_utils import _cast_if_autocast_enabled
from ..enums import AttnMaskType
from apex.transformer.enums import AttnMaskType
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
......
from typing import Optional
import logging
import os
import threading
def get_transformer_logger(name: str) -> logging.Logger:
name_wo_ext = os.path.splitext(name)[0]
return logging.getLogger(name_wo_ext)
def set_logging_level(verbosity) -> None:
"""Change logging severity.
Args:
verbosity: a standard ``logging`` level such as ``logging.INFO`` or ``logging.DEBUG``.
"""
from apex import _library_root_logger
_library_root_logger.setLevel(verbosity)
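# Hedged example of the two helpers above; the import path mirrors the
# `apex.transformer.log_util` imports used later in this diff.
import logging
from apex.transformer.log_util import get_transformer_logger, set_logging_level

set_logging_level(logging.INFO)             # raise the severity of apex's library-root logger
_logger = get_transformer_logger(__name__)  # per-module logger, file extension stripped
_logger.info("pipeline schedule selected")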
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -15,35 +15,41 @@
"""Megatron number of micro-batches calculators."""
from abc import ABC
from abc import abstractmethod
from typing import Optional, List
def build_num_microbatches_calculator(args):
def build_num_microbatches_calculator(
rank: int,
rampup_batch_size: Optional[List[int]],
global_batch_size: int,
micro_batch_size: int,
data_parallel_size: int,
):
# Constant num micro-batches.
if args.rampup_batch_size is None:
if rampup_batch_size is None:
num_microbatches_calculator = ConstantNumMicroBatches(
args.global_batch_size, args.micro_batch_size, args.data_parallel_size
global_batch_size, micro_batch_size, data_parallel_size
)
if args.rank == 0:
if rank == 0:
print(
"setting number of micro-batches to constant {}".format(num_microbatches_calculator.get()), flush=True
)
else:
assert len(args.rampup_batch_size) == 3, (
assert len(rampup_batch_size) == 3, (
"expected the following "
"format: --rampup-batch-size <start batch size> "
"<batch size incerement> <ramp-up samples>"
)
start_batch_size = int(args.rampup_batch_size[0])
batch_size_increment = int(args.rampup_batch_size[1])
ramup_samples = int(args.rampup_batch_size[2])
if args.rank == 0:
start_batch_size = int(rampup_batch_size[0])
batch_size_increment = int(rampup_batch_size[1])
ramup_samples = int(rampup_batch_size[2])
if rank == 0:
print(
"will use batch size rampup starting from global batch "
"size {} to global batch size {} with batch size increments "
"{} over {} samples.".format(
start_batch_size, args.global_batch_size, batch_size_increment, ramup_samples
start_batch_size, global_batch_size, batch_size_increment, ramup_samples
),
flush=True,
)
......@@ -51,9 +57,9 @@ def build_num_microbatches_calculator(args):
start_batch_size,
batch_size_increment,
ramup_samples,
args.global_batch_size,
args.micro_batch_size,
args.data_parallel_size,
global_batch_size,
micro_batch_size,
data_parallel_size,
)
return num_microbatches_calculator
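# Illustrative call with the refactored signature (explicit arguments instead of
# the old `args` namespace); the batch sizes below are placeholders.
from apex.transformer.microbatches import build_num_microbatches_calculator

calculator = build_num_microbatches_calculator(
    rank=0,
    rampup_batch_size=None,  # constant schedule; or [start, increment, ramp-up samples]
    global_batch_size=512,
    micro_batch_size=4,
    data_parallel_size=8,
)
print(calculator.get())      # 512 // (4 * 8) == 16 micro-batches per data-parallel rank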
......@@ -86,6 +92,8 @@ class ConstantNumMicroBatches(NumMicroBatchesCalculator):
assert self.num_micro_batches >= 1
self.current_global_batch_size = global_batch_size
self.micro_batch_size = micro_batch_size
def update(self, consumed_samples, consistency_check):
pass
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -13,11 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model and data parallel groups."""
from typing import Tuple
import torch
# TODO (mkozuki): Consider dissecting utils as this utils import is here
# only for ensure_divisibility
from .tensor_parallel import utils
from apex.transformer.utils import ensure_divisibility
# Intra-layer model parallel group that the current rank belongs to.
......@@ -40,6 +42,9 @@ _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None
# A list of ranks that have a copy of the embedding.
_EMBEDDING_GLOBAL_RANKS = None
# A list of global ranks for each pipeline group to ease calculation of the source
# rank when broadcasting from the first or last pipeline stage
_PIPELINE_GLOBAL_RANKS = None
......@@ -76,23 +81,26 @@ def initialize_model_parallel(
with a total of 16 GPUs, rank 0 to 7 belong to the first box and
ranks 8 to 15 belong to the second box.
"""
if torch.distributed.get_rank() == 0:
print("> initializing tensor model parallel with size {}".format(tensor_model_parallel_size_))
print("> initializing pipeline model parallel with size {}".format(pipeline_model_parallel_size_))
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
world_size = torch.distributed.get_world_size()
tensor_model_parallel_size = min(tensor_model_parallel_size_, world_size)
pipeline_model_parallel_size = min(pipeline_model_parallel_size_, world_size)
# TODO (mkozuki): Consider moving `ensure_divisibility` to this file.
utils.ensure_divisibility(world_size, tensor_model_parallel_size * pipeline_model_parallel_size)
ensure_divisibility(world_size, tensor_model_parallel_size * pipeline_model_parallel_size)
data_parallel_size = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size)
if torch.distributed.get_rank() == 0:
print("> initializing tensor model parallel with size {}".format(tensor_model_parallel_size))
print("> initializing pipeline model parallel with size {}".format(pipeline_model_parallel_size))
print("> initializing data parallel with size {}".format(data_parallel_size))
num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
num_data_parallel_groups = world_size // data_parallel_size
if virtual_pipeline_model_parallel_size_ is not None:
assert pipeline_model_parallel_size_ > 2, \
'pipeline-model-parallel size should be greater than 2 with ' \
'interleaved schedule'
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
......@@ -138,6 +146,7 @@ def initialize_model_parallel(
global _PIPELINE_GLOBAL_RANKS
assert _PIPELINE_MODEL_PARALLEL_GROUP is None, "pipeline model parallel group is already initialized"
global _EMBEDDING_GROUP
global _EMBEDDING_GLOBAL_RANKS
assert _EMBEDDING_GROUP is None, "embedding group is already initialized"
for i in range(num_pipeline_model_parallel_groups):
ranks = range(i, world_size, num_pipeline_model_parallel_groups)
......@@ -154,6 +163,19 @@ def initialize_model_parallel(
group = torch.distributed.new_group(embedding_ranks)
if rank in embedding_ranks:
_EMBEDDING_GROUP = group
if rank in ranks:
_EMBEDDING_GLOBAL_RANKS = embedding_ranks
def get_rank_info() -> Tuple[int, int, int]:
"""Returns a tuple of (tensor, pipeline, data)-parallel-rank for logger."""
if model_parallel_is_initialized():
return (
get_tensor_model_parallel_rank(),
get_pipeline_model_parallel_rank(),
# get_virtual_pipeline_model_parallel_rank(),
get_data_parallel_rank(),
)
return (0, 0, 0)
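# Short example: tag log lines with the (tensor, pipeline, data) parallel ranks.
# get_rank_info() falls back to (0, 0, 0) when model parallelism is not initialized.
from apex.transformer import parallel_state

tp_rank, pp_rank, dp_rank = parallel_state.get_rank_info()
print(f"[tp={tp_rank} pp={pp_rank} dp={dp_rank}] starting iteration", flush=True)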
def model_parallel_is_initialized():
......@@ -193,6 +215,22 @@ def get_embedding_group():
return _EMBEDDING_GROUP
def is_rank_in_embedding_group(ignore_virtual=False):
"""Return true if current rank is in embedding group, False otherwise."""
rank = torch.distributed.get_rank()
global _EMBEDDING_GLOBAL_RANKS
if ignore_virtual:
return rank in _EMBEDDING_GLOBAL_RANKS
if rank in _EMBEDDING_GLOBAL_RANKS:
if rank == _EMBEDDING_GLOBAL_RANKS[0]:
return is_pipeline_first_stage(ignore_virtual=False)
elif rank == _EMBEDDING_GLOBAL_RANKS[-1]:
return is_pipeline_last_stage(ignore_virtual=False)
else:
return True
return False
def set_tensor_model_parallel_world_size(world_size):
"""Set the tensor model parallel size"""
global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
......@@ -344,3 +382,15 @@ def destroy_model_parallel():
_DATA_PARALLEL_GROUP = None
global _EMBEDDING_GROUP
_EMBEDDING_GROUP = None
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
global _MPU_TENSOR_MODEL_PARALLEL_RANK
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
global _MPU_PIPELINE_MODEL_PARALLEL_RANK
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None
from apex.transformer.pipeline_parallel.schedules import get_forward_backward_func
from apex.transformer.pipeline_parallel.schedules.common import build_model
__all__ = [
"get_forward_backward_func",
"build_model",
]
import time
import torch
class _Timer:
"""Timer."""
def __init__(self, name):
self.name_ = name
self.elapsed_ = 0.0
self.started_ = False
self.start_time = time.time()
def start(self):
"""Start the timer."""
assert not self.started_, "timer has already been started"
torch.cuda.synchronize()
self.start_time = time.time()
self.started_ = True
def stop(self):
"""Stop the timer."""
assert self.started_, "timer is not started"
torch.cuda.synchronize()
self.elapsed_ += time.time() - self.start_time
self.started_ = False
def reset(self):
"""Reset timer."""
self.elapsed_ = 0.0
self.started_ = False
def elapsed(self, reset=True):
"""Calculate the elapsed time."""
started_ = self.started_
# If timing is in progress, end it first.
if self.started_:
self.stop()
# Get the elapsed time.
elapsed_ = self.elapsed_
# Reset the elapsed time
if reset:
self.reset()
# If timing was in progress, set it back.
if started_:
self.start()
return elapsed_
class _Timers:
"""Group of timers."""
def __init__(self):
self.timers = {}
def __call__(self, name):
if name not in self.timers:
self.timers[name] = _Timer(name)
return self.timers[name]
def write(self, names, writer, iteration, normalizer=1.0, reset=False):
"""Write timers to a tensorboard writer"""
# currently when using add_scalars,
# torch.utils.add_scalars makes each timer its own run, which
# pollutes the runs list, so we just add each as a scalar
assert normalizer > 0.0
for name in names:
value = self.timers[name].elapsed(reset=reset) / normalizer
writer.add_scalar(name + "-time", value, iteration)
def log(self, names, normalizer=1.0, reset=True):
"""Log a group of timers."""
assert normalizer > 0.0
string = "time (ms)"
for name in names:
elapsed_time = self.timers[name].elapsed(reset=reset) * 1000.0 / normalizer
string += " | {}: {:.2f}".format(name, elapsed_time)
if torch.distributed.is_initialized():
if torch.distributed.get_rank() == (torch.distributed.get_world_size() - 1):
print(string, flush=True)
else:
print(string, flush=True)
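# Minimal sketch of the timer helpers above; assumes a CUDA device is available
# because start()/stop() synchronize the GPU. The sleep stands in for real work.
import time
from apex.transformer.pipeline_parallel._timers import _Timers

timers = _Timers()
timers("forward-compute").start()
time.sleep(0.01)
timers("forward-compute").stop()
timers.log(["forward-compute"])  # prints "time (ms) | forward-compute: ~10.00"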
# coding=utf-8
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import reduce
import operator
from typing import Union, Optional, Tuple
import warnings
import torch
from apex._autocast_utils import _get_current_dtype
from apex.transformer import parallel_state
from apex.transformer.utils import split_tensor_into_1d_equal_chunks
from apex.transformer.utils import gather_split_1d_tensor
from apex.transformer.pipeline_parallel.utils import Shape
from apex.transformer.pipeline_parallel._timers import _Timers
def _run_p2pops(
tensor_send_prev: Union[torch.Tensor, None],
tensor_send_next: Union[torch.Tensor, None],
tensor_recv_prev: Union[torch.Tensor, None],
tensor_recv_next: Union[torch.Tensor, None],
):
ops = []
if tensor_send_prev is not None:
send_prev_op = torch.distributed.P2POp(
torch.distributed.isend,
tensor_send_prev,
parallel_state.get_pipeline_model_parallel_prev_rank(),
)
ops.append(send_prev_op)
if tensor_recv_prev is not None:
recv_prev_op = torch.distributed.P2POp(
torch.distributed.irecv,
tensor_recv_prev,
parallel_state.get_pipeline_model_parallel_prev_rank(),
)
ops.append(recv_prev_op)
if tensor_send_next is not None:
send_next_op = torch.distributed.P2POp(
torch.distributed.isend,
tensor_send_next,
parallel_state.get_pipeline_model_parallel_next_rank(),
)
ops.append(send_next_op)
if tensor_recv_next is not None:
recv_next_op = torch.distributed.P2POp(
torch.distributed.irecv,
tensor_recv_next,
parallel_state.get_pipeline_model_parallel_next_rank(),
)
ops.append(recv_next_op)
if len(ops) > 0:
reqs = torch.distributed.batch_isend_irecv(ops)
for req in reqs:
req.wait()
def _communicate(
tensor_send_next: Optional[torch.Tensor],
tensor_send_prev: Optional[torch.Tensor],
recv_prev: bool,
recv_next: bool,
tensor_shape: Optional[Shape] = None,
override_scatter_gather_tensors_in_pipeline: bool = False,
dtype_: torch.dtype = torch.float,
*,
scatter_gather_tensors_in_pipeline: bool = True,
params_dtype: Optional[torch.dtype] = None,
fp32_residual_connection: bool = False,
) -> Tuple[Union[torch.Tensor, None], Union[torch.Tensor, None]]:
"""Base function for communication of tensors between stages.
Args:
tensor_send_next: tensor to send to next rank (no tensor sent if set to None).
tensor_send_prev: tensor to send to prev rank (no tensor sent if set to None).
recv_prev: boolean for whether tensor should be received from previous rank.
recv_next: boolean for whether tensor should be received from next rank.
tensor_shape: optional; use when the input sequence contains fewer tokens than the default sequence length
override_scatter_gather_tensors_in_pipeline:
optional; set this when `tensor_shape` is provided to bypass the scatter/gather optimization
dtype_: dtype of the tensors being communicated; used together with `tensor_shape`
Keyword args:
scatter_gather_tensors_in_pipeline: Optional. If :obj:`True`, use scatter/gather to optimize communication of tensors.
params_dtype: Optional and legacy. Defaults to torch.float. If you manually call `.half()` or `.bfloat16()` on
your model deliberately, pass this argument.
fp32_residual_connection: Optional. If :obj:`True`, move residual connections to fp32.
Returns:
tuple containing
- tensor_recv_prev: `torch.Tensor` if `recv_prev` is :obj:`True`, `None` otherwise.
- tensor_recv_next: `torch.Tensor` if `recv_next` is :obj:`True`, `None` otherwise.
"""
# Create placeholder tensors for receive in forward and backward directions if needed.
tensor_recv_prev = None
tensor_recv_next = None
if tensor_shape is None:
# In megatron, `tensor_shape` is set to `(args.seq_length, args.micro_batch_size, args.hidden_size)`
raise RuntimeError(
"`tensor_shape` must be specified. Common `tensor_shape` is `(seq_length, micro_batch_size, hidden_size)`")
if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
tensor_chunk_shape = (reduce(operator.mul, tensor_shape, 1) // parallel_state.get_tensor_model_parallel_world_size(),)
else:
tensor_chunk_shape = tensor_shape
# NOTE(mkozuki): In PyTorch AMP, i.e. `torch.cuda.amp.autocast` context, activation tensors can be either FP32,
# FP16, or BF16 and there's no way to tell the dtypes of tensors on different devices in general.
# It might be possible if we restrict model architecture.
# dtype = params_dtype or torch.float
# if fp32_residual_connection:
# dtype = torch.float
# if dtype_ is not None:
# dtype = dtype_
# requires_grad = False
if dtype_ != torch.float32 or params_dtype is not None:
if torch.distributed.get_rank() == 0:
warnings.warn("Tensor P2P communications are executed in FP32")
dtype = torch.float32
requires_grad = True
if recv_prev:
tensor_recv_prev = torch.empty(
tensor_chunk_shape,
requires_grad=requires_grad,
device=torch.cuda.current_device(),
dtype=dtype,
)
if recv_next:
tensor_recv_next = torch.empty(
tensor_chunk_shape,
requires_grad=requires_grad,
device=torch.cuda.current_device(),
dtype=dtype,
)
# Split tensor into smaller chunks if using scatter-gather optimization.
if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
if tensor_send_next is not None:
tensor_send_next = split_tensor_into_1d_equal_chunks(tensor_send_next)
if tensor_send_prev is not None:
tensor_send_prev = split_tensor_into_1d_equal_chunks(tensor_send_prev)
# Send tensors in both the forward and backward directions as appropriate.
_run_p2pops(tensor_send_prev, tensor_send_next, tensor_recv_prev, tensor_recv_next)
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()
# If using scatter-gather optimization, gather smaller chunks.
if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
if recv_prev:
tensor_recv_prev = (
gather_split_1d_tensor(tensor_recv_prev)
.view(tensor_shape)
.requires_grad_()
)
if recv_next:
tensor_recv_next = (
gather_split_1d_tensor(tensor_recv_next)
.view(tensor_shape)
.requires_grad_()
)
return tensor_recv_prev, tensor_recv_next
def recv_forward(
tensor_shape: Shape,
override_scatter_gather_tensors_in_pipeline: bool = False,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
) -> torch.Tensor:
"""Receive tensor from previous rank in pipeline (forward receive)."""
if parallel_state.is_pipeline_first_stage():
return None
if timers is not None:
timers("forward-recv").start()
input_tensor, _ = _communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_prev=True,
recv_next=False,
tensor_shape=tensor_shape,
override_scatter_gather_tensors_in_pipeline=override_scatter_gather_tensors_in_pipeline,
dtype_=_get_current_dtype(dtype),
)
if timers is not None:
timers("forward-recv").stop()
return input_tensor
def recv_backward(
tensor_shape: Shape = None,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
):
"""Receive tensor from next rank in pipeline (backward receive)."""
if parallel_state.is_pipeline_last_stage():
return None
if timers is not None:
timers("backward-recv").start()
_, output_tensor_grad = _communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_prev=False,
recv_next=True,
tensor_shape=tensor_shape,
dtype_=_get_current_dtype(dtype),
)
if timers is not None:
timers("backward-recv").stop()
return output_tensor_grad
def send_forward(
output_tensor: torch.Tensor,
override_scatter_gather_tensors_in_pipeline: bool = False,
tensor_shape: Shape = None,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
) -> None:
"""Send tensor to next rank in pipeline (forward send)."""
if parallel_state.is_pipeline_last_stage():
return
if timers is not None:
timers("forward-send").start()
_communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=False,
recv_next=False,
override_scatter_gather_tensors_in_pipeline=override_scatter_gather_tensors_in_pipeline,
tensor_shape=tensor_shape,
dtype_=_get_current_dtype(dtype),
)
if timers is not None:
timers("forward-send").stop()
def send_backward(
input_tensor_grad: torch.Tensor,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
) -> None:
"""Send tensor to previous rank in pipeline (backward send)."""
if parallel_state.is_pipeline_first_stage():
return
if timers is not None:
timers("backward-send").start()
_communicate(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=False,
recv_next=False,
tensor_shape=tensor_shape,
dtype_=_get_current_dtype(dtype),
)
if timers is not None:
timers("backward-send").stop()
def send_forward_recv_backward(
output_tensor: torch.Tensor,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
) -> torch.Tensor:
"""Batched send and recv with next rank in pipeline."""
if parallel_state.is_pipeline_last_stage():
return None
if timers is not None:
timers("forward-send-backward-recv").start()
_, output_tensor_grad = _communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=False,
recv_next=True,
tensor_shape=tensor_shape,
dtype_=_get_current_dtype(dtype),
)
if timers is not None:
timers("forward-send-backward-recv").stop()
return output_tensor_grad
def send_backward_recv_forward(
input_tensor_grad: torch.Tensor,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
) -> torch.Tensor:
"""Batched send and recv with previous rank in pipeline."""
if parallel_state.is_pipeline_first_stage():
return None
if timers is not None:
timers("backward-send-forward-recv").start()
input_tensor, _ = _communicate(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=True,
recv_next=False,
tensor_shape=tensor_shape,
dtype_=_get_current_dtype(dtype),
)
if timers is not None:
timers("backward-send-forward-recv").stop()
return input_tensor
def send_forward_recv_forward(
output_tensor: torch.Tensor,
recv_prev: bool,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
) -> torch.Tensor:
"""Batched recv from previous rank and send to next rank in pipeline."""
if timers is not None:
timers("forward-send-forward-recv").start()
input_tensor, _ = _communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=recv_prev,
recv_next=False,
tensor_shape=tensor_shape,
dtype_=_get_current_dtype(dtype),
)
if timers is not None:
timers("forward-send-forward-recv").stop()
return input_tensor
def send_backward_recv_backward(
input_tensor_grad: torch.Tensor,
recv_next: bool,
tensor_shape: Shape,
*,
dtype: torch.dtype = torch.float,
timers: _Timers = None,
) -> torch.Tensor:
"""Batched recv from next rank and send to previous rank in pipeline."""
if timers is not None:
timers("backward-send-backward-recv").start()
_, output_tensor_grad = _communicate(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=False,
recv_next=recv_next,
tensor_shape=tensor_shape,
dtype_=_get_current_dtype(dtype),
)
if timers is not None:
timers("backward-send-backward-recv").stop()
return output_tensor_grad
def send_forward_backward_recv_forward_backward(
output_tensor: torch.Tensor,
input_tensor_grad: torch.Tensor,
recv_prev: bool,
recv_next: bool,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
):
"""Batched send and recv with previous and next ranks in pipeline."""
if timers is not None:
timers("forward-backward-send-forward-backward-recv").start()
input_tensor, output_tensor_grad = _communicate(
tensor_send_next=output_tensor,
tensor_send_prev=input_tensor_grad,
recv_prev=recv_prev,
recv_next=recv_next,
tensor_shape=tensor_shape,
dtype_=_get_current_dtype(dtype),
)
if timers is not None:
timers("forward-backward-send-forward-backward-recv").stop()
return input_tensor, output_tensor_grad
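# Hedged sketch of a single forward hop between adjacent pipeline stages, assuming
# torch.distributed and parallel_state are initialized on every rank; `stage_module`
# and the shape values are placeholders.
tensor_shape = (1024, 4, 1536)  # (seq_length, micro_batch_size, hidden_size)

input_tensor = recv_forward(tensor_shape=tensor_shape)  # None on the first stage
output_tensor = stage_module(input_tensor)              # hypothetical per-stage forward
send_forward(output_tensor, tensor_shape=tensor_shape)  # no-op on the last stage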
import warnings
from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import forward_backward_no_pipelining
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import _forward_backward_pipelining_with_interleaving
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import (
forward_backward_pipelining_without_interleaving,
)
class ExperimentalWarning(Warning):
pass
def get_forward_backward_func(
virtual_pipeline_model_parallel_size, pipeline_model_parallel_size,
):
if parallel_state.get_pipeline_model_parallel_world_size() > 1:
if virtual_pipeline_model_parallel_size is not None:
if get_num_microbatches() % pipeline_model_parallel_size != 0:
msg = "number of microbatches is not divisible by pipeline-parallel size when using interleaved schedule"
raise RuntimeError(msg)
warnings.warn(
"Pipeline Model Parallel with interleaving scheduling is experimental. "
f"To use Pipeline Parallel without interleaving, set `virtual_pipeline_model_parallel_size` to `None`: {virtual_pipeline_model_parallel_size}",
ExperimentalWarning
)
forward_backward_func = _forward_backward_pipelining_with_interleaving
else:
forward_backward_func = forward_backward_pipelining_without_interleaving
else:
forward_backward_func = forward_backward_no_pipelining
return forward_backward_func
__all__ = [
"get_forward_backward_func",
]
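# Hedged sketch of schedule selection and invocation; assumes parallel_state and the
# microbatch calculator are set up, and `forward_step_func`, `batch`, `model` are placeholders.
from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel.schedules import get_forward_backward_func

fwd_bwd_func = get_forward_backward_func(
    virtual_pipeline_model_parallel_size=None,  # no interleaving
    pipeline_model_parallel_size=parallel_state.get_pipeline_model_parallel_world_size(),
)
losses = fwd_bwd_func(
    forward_step_func,
    batch,
    model,
    forward_only=False,
    tensor_shape=(1024, 4, 1536),  # (seq_length, micro_batch_size, hidden_size)
)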
# NOTE (mkozuki): For simplicity, tentatively `timers` related operations are commented out.
from typing import Any, Callable, Dict, List, Tuple, Union, Optional
import torch
from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.utils import listify_model
from apex.transformer.pipeline_parallel.utils import unwrap_model
from apex.transformer.tensor_parallel.layers import set_defaults_if_not_set_tensor_model_parallel_attributes
Batch = Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]
LossFunc = Callable[[torch.Tensor], torch.Tensor]
FwdStepFunc = Callable[[Batch, torch.nn.Module], Tuple[torch.Tensor, LossFunc]]
def build_model(
model_provider_func: Callable[[Any, Dict[str, Any]], torch.nn.Module],
wrap_with_ddp: bool = True,
virtual_pipeline_model_parallel_size: Optional[int] = None,
*args,
**kwargs
) -> List[torch.nn.Module]:
"""Build the model satisfying pipeline model parallel requirements.
This function sets `pre_process` and `post_process` to `**kwargs` and pass `*args` and `**kwargs` to
`model_provider_func`.
Args:
model_provider_func: A function which takes `*args` and `**kwargs` and returns a `nn.Module`.
wrap_with_ddp: If :obj:`True`, wrap the instantiated model
with `torch.nn.parallel.distributed.DistributedDataParallel`, a.k.a. `DDP`.
virtual_pipeline_model_parallel_size: Specify when using interleaving scheduling pipeline model parallel.
*args: arguments for model provider func
**kwargs: Keyword arguments for model provider func
Returns:
a list of `nn.Module`(s). If `virtual_pipeline_model_parallel_size` is not None,
the list has multiple models, otherwise one.
"""
if (
parallel_state.get_pipeline_model_parallel_world_size() > 1 and
virtual_pipeline_model_parallel_size is not None
):
model = []
for i in range(virtual_pipeline_model_parallel_size):
cur_args = args
cur_kwargs = kwargs
parallel_state.set_virtual_pipeline_model_parallel_rank(i)
# Set pre_process and post_process only after virtual rank is set.
pre_process = parallel_state.is_pipeline_first_stage()
post_process = parallel_state.is_pipeline_last_stage()
cur_kwargs.update({
"pre_process": pre_process,
"post_process": post_process,
})
this_model = model_provider_func(*cur_args, **cur_kwargs)
model.append(this_model)
else:
cur_args = args
cur_kwargs = kwargs
pre_process = parallel_state.is_pipeline_first_stage()
post_process = parallel_state.is_pipeline_last_stage()
cur_kwargs.update({
"pre_process": pre_process,
"post_process": post_process,
})
model = model_provider_func(*cur_args, **cur_kwargs)
if not isinstance(model, list):
model = [model]
# Set tensor model parallel attributes if not set.
# Only parameters that are already tensor model parallel have these
# attributes set for them. We should make sure the default attributes
# are set for all params so the optimizer can use them.
for model_module in model:
for param in model_module.parameters():
set_defaults_if_not_set_tensor_model_parallel_attributes(param)
# Print number of parameters.
if parallel_state.get_data_parallel_rank() == 0:
msg = " > number of parameters on (tensor, pipeline) model parallel rank ({}, {}): {}".format(
parallel_state.get_tensor_model_parallel_rank(),
parallel_state.get_pipeline_model_parallel_rank(),
sum([sum([p.nelement() for p in model_module.parameters()]) for model_module in model])
)
print(msg, flush=True)
# GPU allocation.
for model_module in model:
model_module.cuda(torch.cuda.current_device())
if wrap_with_ddp:
i = torch.cuda.current_device()
model = [
torch.nn.parallel.distributed.DistributedDataParallel(
model_module,
device_ids=[i],
output_device=i,
process_group=parallel_state.get_data_parallel_group(),
)
for model_module in model
]
return model
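# Hedged sketch of a model provider passed to build_model; `MyTransformerLM` is a
# hypothetical stand-in for a model that honors pre_process/post_process.
def model_provider_func(pre_process=True, post_process=True):
    return MyTransformerLM(pre_process=pre_process, post_process=post_process)

models = build_model(
    model_provider_func,
    wrap_with_ddp=True,
    virtual_pipeline_model_parallel_size=None,  # one chunk per pipeline stage
)
# `models` is always a list; with interleaving it holds one module per virtual chunk.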
def _get_params_for_weight_decay_optimization(
model: Union[torch.nn.Module, List[torch.nn.Module]],
) -> Dict[str, torch.nn.Parameter]:
"""Divide params into with-weight-decay and without-weight-decay groups.
Layernorms and biases will have no weight decay but the rest will.
"""
modules = listify_model(model)
from apex.normalization.fused_layer_norm import FusedLayerNorm # NOQA
weight_decay_params = {'params': []}
no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
for module in modules:
for module_ in module.modules():
if isinstance(module_, FusedLayerNorm):
no_weight_decay_params['params'].extend(
[p for p in list(module_._parameters.values())
if p is not None])
else:
weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and n != 'bias'])
no_weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and n == 'bias'])
return weight_decay_params, no_weight_decay_params
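# Hedged usage of the helper above: biases (and apex FusedLayerNorm parameters) end up
# in the zero-weight-decay group; requires the apex fused layer norm extension.
import torch

toy_model = torch.nn.Linear(16, 16)  # placeholder model
wd_group, no_wd_group = _get_params_for_weight_decay_optimization(toy_model)
optimizer = torch.optim.AdamW([wd_group, no_wd_group], lr=1e-4, weight_decay=0.01)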
def forward_step(
forward_step_func: FwdStepFunc,
batch: Batch,
model: torch.nn.Module,
input_tensor: Optional[torch.Tensor],
losses_reduced: List[torch.Tensor],
):
"""Forward step for passed-in model.
If first stage, input tensor is obtained from data_iterator, otherwise
passed-in input_tensor is used.
Returns output tensor.
Args:
forward_step_func: Model specific function. This takes a minibatch and model as its arguments and
returns the model's output and the loss function.
batch: Current microbatch of data.
model: The model for this stage; once unwrapped it must expose `set_input_tensor`.
input_tensor: Output of the previous pipeline stage, or None on the first stage.
losses_reduced: List that the reduced loss dict is appended to on the last stage.
Returns:
output_tensor
"""
# timers = get_timers()
# timers("forward-compute").start()
unwrapped_model = unwrap_model(model)
# NOTE (mkozuki): The passed `model` is expected to implement `set_input_tensor`.
# See https://github.com/NVIDIA/Megatron-LM/blob/5ac5571ba0265af4c491ee0af1508ca7589450c6/megatron/model/transformer.py#L679 # NOQA
# for the details of `set_input_tensor`.
unwrapped_model.set_input_tensor(input_tensor)
output_tensor, loss_func = forward_step_func(batch, model)
# print(f"forward_step| pipeline rank: {parallel_state.get_pipeline_model_parallel_rank()} is_pipeline_last_stage?: {parallel_state.is_pipeline_last_stage()}")
if parallel_state.is_pipeline_last_stage():
output_tensor = loss_func(output_tensor)
loss, loss_reduced = output_tensor
output_tensor = loss / get_num_microbatches()
losses_reduced.append(loss_reduced)
# timers("forward-compute").stop()
return output_tensor
def backward_step(
input_tensor: Optional[torch.Tensor],
output_tensor: torch.Tensor,
output_tensor_grad: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
"""Backward step through passed-in output tensor.
If last stage, output_tensor_grad is None, otherwise gradient of loss
with respect to stage's output tensor.
Returns gradient of loss with respect to input tensor (None if first
stage).
Args:
input_tensor: Input to this stage; its gradient is returned (None on the first stage).
output_tensor: Output of this stage to backpropagate through.
output_tensor_grad: Gradient of the loss with respect to `output_tensor` (None on the last stage).
Returns:
input_tensor_grad
"""
# timers = get_timers()
# timers("backward-compute").start()
# Retain the grad on the input_tensor.
# if parallel_state.get_pipeline_model_parallel_rank() == 0:
# print(f"{input_tensor}, {output_tensor}, {output_tensor_grad}")
if input_tensor is not None:
input_tensor.retain_grad()
# Backward pass.
# if output_tensor_grad is None:
# output_tensor = optimizer.scale_loss(output_tensor)
torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
input_tensor_grad = None
if input_tensor is not None:
input_tensor_grad = input_tensor.grad
# timers("backward-compute").stop()
return input_tensor_grad
from contextlib import contextmanager
from typing import List, Union
import torch
from apex.transformer.pipeline_parallel.utils import listify_model
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.utils import get_kth_microbatch
from apex.transformer.pipeline_parallel.schedules.common import Batch, FwdStepFunc
from apex.transformer.pipeline_parallel.schedules.common import forward_step
from apex.transformer.pipeline_parallel.schedules.common import backward_step
from apex.transformer.log_util import get_transformer_logger
__all__ = ["forward_backward_no_pipelining"]
_logger = get_transformer_logger(__name__)
@contextmanager
def placeholder_handler():
try:
yield
finally:
pass
def forward_backward_no_pipelining(
forward_step_func: FwdStepFunc,
batch: Batch,
model: Union[torch.nn.Module, List[torch.nn.Module]],
*,
forward_only: bool,
**kwargs,
):
"""Run forward and backward passes with no pipeline parallelism (no inter-stage communication).
This pipeline parallel scheduling handles the last microbatch differently to synchronize gradients.
Args:
forward_step_func: A function which takes a minibatch and model as its arguments and
returns model's forward output and the loss function.
The loss function is supposed to take one `torch.Tensor` and
return a `torch.Tensor` of loss and a dictionary of `str` and `torch.Tensor`.
batch: A List of torch.Tensors
model: A `torch.nn.Module` or a list of `torch.nn.Module`.
Keyword args:
forward_only: If True, skip the backward passes.
**kwargs: Added to handle `tensor_shape` which has no effect on this function.
Returns:
a list of dictionaries of loss `torch.Tensor`s if the last stage, empty list otherwise.
"""
model = listify_model(model)
if len(model) != 1:
msg = f"`model` is expected be a `nn.Module`, but {type(model)}"
raise RuntimeError(msg)
model = model[0]
context_handler = placeholder_handler
if isinstance(model, torch.nn.parallel.distributed.DistributedDataParallel):
context_handler = model.no_sync
losses_reduced = []
input_tensor, output_tensor_grad = None, None
num_micro_batches = get_num_microbatches()
with context_handler():
for i in range(num_micro_batches - 1):
_logger.info(f"Iter {i} of {num_micro_batches - 1}")
cur_micro_batch = get_kth_microbatch(batch, i)
_logger.debug("Call `forward_step`")
output_tensor = forward_step(
forward_step_func, cur_micro_batch, model, input_tensor, losses_reduced)
if not forward_only:
_logger.debug("Call `backward_step`")
backward_step(input_tensor, output_tensor, output_tensor_grad)
# Run computation for last microbatch out of context handler (want to
# synchronize gradients).
_logger.info("Cooldown")
_logger.debug("Call `forward_step`")
output_tensor = forward_step(
forward_step_func, get_kth_microbatch(batch, num_micro_batches - 1), model, input_tensor, losses_reduced
)
if not forward_only:
_logger.debug("Call `backward_step`")
backward_step(input_tensor, output_tensor, output_tensor_grad)
return losses_reduced
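# Minimal sketch: with no pipeline parallelism, every microbatch except the last runs
# inside DDP's no_sync() so gradients are all-reduced once; `forward_step_func`,
# `batch`, and `ddp_model` are placeholders matching the docstring above.
losses_reduced = forward_backward_no_pipelining(
    forward_step_func,
    batch,
    ddp_model,
    forward_only=False,
)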
from typing import List, Union, Optional
import torch
from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel import p2p_communication
from apex.transformer.pipeline_parallel.schedules.common import Batch, FwdStepFunc
from apex.transformer.pipeline_parallel.schedules.common import backward_step
from apex.transformer.pipeline_parallel.schedules.common import forward_step
from apex.transformer.pipeline_parallel.utils import get_kth_microbatch
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.log_util import get_transformer_logger
__all__ = ["_forward_backward_pipelining_with_interleaving"]
_logger = get_transformer_logger(__name__)
# TODO (mkozuki): Reduce cyclomatic complexity
def _forward_backward_pipelining_with_interleaving(
forward_step_func: FwdStepFunc,
batch: List[Batch],
model: List[torch.nn.Module],
*,
forward_only: bool,
tensor_shape: Optional[Union[List[int], torch.Size]] = None,
):
"""Run interleaved 1F1B schedule with communication between pipeline stages as needed.
This function assumes `batch` and `model` are a list of `Batch`es and a list of `torch.nn.Module`s, respectively.
This means that the model is split into model chunks.
This pipeline parallel scheduling consists of three steps:
1. warmup
2. 1F1B a.k.a. steady state
3. cooldown
Note that if `forward_only` is True, this schedule consists of only the warmup phase.
Args:
forward_step_func: A function which takes a minibatch and model as its arguments and
returns model's forward output and the loss function.
The loss function is supposed to take one `torch.Tensor` and
return a `torch.Tensor` of loss and a dictionary of `str` and `torch.Tensor`.
batch: A minibatch, i.e., a list of `torch.Tensor`'s.
model: A `torch.nn.Module` or a list of `torch.nn.Module`.
Keyword args:
forward_only: If True, run only the forward (warmup) passes.
tensor_shape: Shape of the tensors exchanged between pipeline stages, typically (seq_length, micro_batch_size, hidden_size).
Returns:
a list of loss `torch.Tensor`s if the last stage, empty list otherwise.
"""
if not isinstance(model, list):
raise RuntimeError("`model` must be a list of `nn.Module`'s'")
num_model_chunks = len(model)
input_tensors = [[] for _ in range(num_model_chunks)]
output_tensors = [[] for _ in range(num_model_chunks)]
curr_iters = [0 for _ in range(num_model_chunks)]
losses_reduced = []
if not forward_only:
output_tensor_grads = [[] for _ in range(num_model_chunks)]
pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
pipeline_parallel_rank = parallel_state.get_pipeline_model_parallel_rank()
# Compute number of warmup and remaining microbatches.
num_microbatches = get_num_microbatches() * num_model_chunks
all_warmup_microbatches = False
if forward_only:
num_warmup_microbatches = num_microbatches
else:
# Run all forward passes and then all backward passes if number of
# microbatches is just the number of pipeline stages.
# Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size microbatches on
# all workers, followed by additional microbatches depending on
# stage ID (more forward passes for earlier stages; later stages can
# immediately start with 1F1B).
if get_num_microbatches() == pipeline_parallel_size:
num_warmup_microbatches = num_microbatches
all_warmup_microbatches = True
else:
num_warmup_microbatches = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size
num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
num_microbatches_remaining = num_microbatches - num_warmup_microbatches
_logger.info(
f"num_microbatches: {num_microbatches}, "
f"num_warmup_microbatches: {num_warmup_microbatches}, "
f"num_microbatches_remaining: {num_microbatches_remaining}"
)
###################################################################################################################
# Helper function definitions.
###################################################################################################################
def get_model_chunk_id(microbatch_id: int, forward: bool) -> int:
"""Helper function to get the model chunk ID given the iteration number."""
pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
if not forward:
model_chunk_id = num_model_chunks - model_chunk_id - 1
return model_chunk_id
def forward_step_helper(microbatch_id, curr_iters):
"""Helper method to run forward step with model split into chunks
(run set_virtual_pipeline_model_parallel_rank() before calling
forward_step())."""
model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
# forward step
if (
parallel_state.is_pipeline_first_stage() and
len(input_tensors[model_chunk_id]) == len(output_tensors[model_chunk_id])
):
input_tensors[model_chunk_id].append(None)
input_tensor = input_tensors[model_chunk_id][-1]
output_tensor = forward_step(
forward_step_func,
get_kth_microbatch(batch, curr_iters[model_chunk_id]),
model[model_chunk_id],
input_tensor,
losses_reduced,
)
curr_iters[model_chunk_id] += 1
output_tensors[model_chunk_id].append(output_tensor)
# if forward-only, no need to save tensors for a backward pass
if forward_only:
input_tensors[model_chunk_id].pop()
output_tensors[model_chunk_id].pop()
return output_tensor
def backward_step_helper(microbatch_id):
"""Helper method to run backward step with model split into chunks
(run set_virtual_pipeline_model_parallel_rank() before calling
backward_step())."""
model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
if parallel_state.is_pipeline_last_stage():
if len(output_tensor_grads[model_chunk_id]) == 0:
output_tensor_grads[model_chunk_id].append(None)
input_tensor = input_tensors[model_chunk_id].pop(0)
output_tensor = output_tensors[model_chunk_id].pop(0)
output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
input_tensor_grad = backward_step(input_tensor, output_tensor, output_tensor_grad)
return input_tensor_grad
###################################################################################################################
# Run warmup forward passes.
###################################################################################################################
parallel_state.set_virtual_pipeline_model_parallel_rank(0)
input_tensors[0].append(p2p_communication.recv_forward(tensor_shape=tensor_shape))
_logger.info("Warmup phase")
for k in range(num_warmup_microbatches):
_logger.debug(f"warmup iter: {k} / {num_warmup_microbatches}")
output_tensor = forward_step_helper(k, curr_iters)
# Determine if tensor should be received from previous stage.
next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True)
recv_prev = True
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
if next_forward_model_chunk_id == 0:
recv_prev = False
if k == (num_microbatches - 1):
recv_prev = False
_logger.debug(f"next fwd model chunk ID: {next_forward_model_chunk_id}, recv_prev: {recv_prev}")
# Don't send tensor downstream if on last stage.
if parallel_state.is_pipeline_last_stage():
_logger.debug("Pipeline last stage, not sending tensor downstream")
output_tensor = None
# Send and receive tensors as appropriate (send tensors computed
# in this iteration; receive tensors for next iteration).
if k == (num_warmup_microbatches - 1) and not forward_only and not all_warmup_microbatches:
input_tensor_grad = None
recv_next = True
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
recv_next = False
_logger.debug("send fwd&bwd and receive fwd&bwd")
(
input_tensor,
output_tensor_grad,
) = p2p_communication.send_forward_backward_recv_forward_backward(
output_tensor,
input_tensor_grad,
recv_prev=recv_prev,
recv_next=recv_next,
tensor_shape=tensor_shape,
)
output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
else:
_logger.debug("send fwd and receive fwd")
input_tensor = p2p_communication.send_forward_recv_forward(output_tensor, recv_prev=recv_prev, tensor_shape=tensor_shape)
input_tensors[next_forward_model_chunk_id].append(input_tensor)
###################################################################################################################
# Run 1F1B in steady state.
###################################################################################################################
_logger.info("Steady phase")
for k in range(num_microbatches_remaining):
# Forward pass.
_logger.debug(f" steady phase iter {k} / {num_microbatches_remaining}")
forward_k = k + num_warmup_microbatches
output_tensor = forward_step_helper(forward_k, curr_iters)
# Backward pass.
backward_k = k
input_tensor_grad = backward_step_helper(backward_k)
# Send output_tensor and input_tensor_grad, receive input_tensor
# and output_tensor_grad.
# Determine if current stage has anything to send in either direction,
# otherwise set tensor to None.
forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
if parallel_state.is_pipeline_last_stage():
output_tensor = None
backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
_logger.debug(f"fwd/bwd model chunk id: {forward_model_chunk_id}/{backward_model_chunk_id}")
if parallel_state.is_pipeline_first_stage():
input_tensor_grad = None
# Determine if peers are sending, and where in data structure to put
# received tensors.
recv_prev = True
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
# First stage is ahead of last stage by (pipeline_parallel_size - 1).
next_forward_model_chunk_id = get_model_chunk_id(
forward_k - (pipeline_parallel_size - 1), forward=True
)
if next_forward_model_chunk_id == (num_model_chunks - 1):
recv_prev = False
next_forward_model_chunk_id += 1
else:
next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True)
recv_next = True
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
# Last stage is ahead of first stage by (pipeline_parallel_size - 1).
next_backward_model_chunk_id = get_model_chunk_id(
backward_k - (pipeline_parallel_size - 1), forward=False
)
if next_backward_model_chunk_id == 0:
recv_next = False
next_backward_model_chunk_id -= 1
else:
next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False)
# If last iteration, don't receive; we already received one extra
# before the start of the for loop.
if k == (num_microbatches_remaining - 1):
recv_prev = False
# Communicate tensors.
_logger.debug("send fwd&bwd and receive fwd&bwd")
(
input_tensor,
output_tensor_grad,
) = p2p_communication.send_forward_backward_recv_forward_backward(
output_tensor,
input_tensor_grad,
recv_prev=recv_prev,
recv_next=recv_next,
tensor_shape=tensor_shape,
)
# Put input_tensor and output_tensor_grad in data structures in the
# right location.
if recv_prev:
input_tensors[next_forward_model_chunk_id].append(input_tensor)
if recv_next:
output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad)
###################################################################################################################
# Run cooldown backward passes (flush out pipeline).
###################################################################################################################
_logger.info("Cooldown phase")
if not forward_only:
if all_warmup_microbatches:
output_tensor_grads[num_model_chunks - 1].append(p2p_communication.recv_backward(tensor_shape=tensor_shape))
for k in range(num_microbatches_remaining, num_microbatches):
_logger.debug(f"cooldown iter {k} in range({num_microbatches_remaining}, {num_microbatches})")
input_tensor_grad = backward_step_helper(k)
next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
recv_next = True
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
if next_backward_model_chunk_id == (num_model_chunks - 1):
recv_next = False
if k == (num_microbatches - 1):
recv_next = False
output_tensor_grads[next_backward_model_chunk_id].append(
p2p_communication.send_backward_recv_backward(input_tensor_grad, recv_next=recv_next, tensor_shape=tensor_shape)
)
return losses_reduced
from typing import Union, List, Optional
import torch
from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel import p2p_communication
from apex.transformer.pipeline_parallel.utils import get_kth_microbatch
from apex.transformer.pipeline_parallel.utils import listify_model
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.schedules.common import Batch, FwdStepFunc
from apex.transformer.pipeline_parallel.schedules.common import forward_step
from apex.transformer.pipeline_parallel.schedules.common import backward_step
from apex.transformer.log_util import get_transformer_logger
__all__ = ["forward_backward_pipelining_without_interleaving"]
_logger = get_transformer_logger(__name__)
def forward_backward_pipelining_without_interleaving(
forward_step_func: FwdStepFunc,
batch: Batch,
model: Union[torch.nn.Module, List[torch.nn.Module]],
*,
forward_only: bool,
tensor_shape: Optional[Union[List[int], torch.Size]] = None,
):
"""Run non-interleaved 1F1B schedule, with communication between pipeline stages.
This pipeline parallel scheduling consists of three steps:
1. warmup
2. 1F1B a.k.a. steady state
3. cooldown if not forward_only
Args:
forward_step_func: A function which takes a minibatch and model as its arguments and
returns model's forward output and the loss function.
The loss function is supposed to take one `torch.Tensor` and
return a `torch.Tensor` of loss and a dictionary of `str` and `torch.Tensor`.
batch: A minibatch, i.e., a list of `torch.Tensor`'s.
model: A `torch.nn.Module` or a list of `torch.nn.Module`.
Keyword args:
forward_only: If True, skip the backward passes and the cooldown phase.
tensor_shape: Shape of the tensors exchanged between pipeline stages. Required for P2P communication.
Returns:
a list of loss `torch.Tensor`s if the last stage, empty list otherwise.
"""
# timers = get_timers()
model = listify_model(model)
if len(model) != 1:
msg = f"`model` is expected be a `nn.Module`, but {type(model)}"
raise RuntimeError(msg)
model = model[0]
# Compute number of warmup microbatches.
num_microbatches = get_num_microbatches()
num_warmup_microbatches = (
parallel_state.get_pipeline_model_parallel_world_size()
- parallel_state.get_pipeline_model_parallel_rank()
- 1
)
num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
num_microbatches_remaining = num_microbatches - num_warmup_microbatches
_logger.info(
f"num_microbatches: {num_microbatches}, "
f"num_warmup_microbatches: {num_warmup_microbatches}, "
f"num_microbatches_remaining: {num_microbatches_remaining}"
)
# Input, output tensors only need to be saved when doing backward passes
input_tensors = None
output_tensors = None
if not forward_only:
input_tensors = []
output_tensors = []
losses_reduced = []
###################################################################################################################
# Run warmup forward passes.
###################################################################################################################
_logger.info("Warmup")
for i in range(num_warmup_microbatches):
_logger.debug(f"warmup iter: {i} / {num_warmup_microbatches}")
_logger.debug("receive fwd")
input_tensor = p2p_communication.recv_forward(tensor_shape=tensor_shape)
cur_microbatch = get_kth_microbatch(batch, i)
output_tensor = forward_step(forward_step_func, cur_microbatch, model, input_tensor, losses_reduced)
_logger.debug("send fwd")
p2p_communication.send_forward(output_tensor, tensor_shape=tensor_shape)
if not forward_only:
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
# Before running 1F1B, need to receive first forward tensor.
# If all microbatches are run in warmup / cooldown phase, then no need to
# receive this tensor here.
if num_microbatches_remaining > 0:
_logger.debug("recv_forward before steady state start")
input_tensor = p2p_communication.recv_forward(tensor_shape=tensor_shape)
###################################################################################################################
# Run 1F1B in steady state.
###################################################################################################################
_logger.info("Steady phase")
for i in range(num_microbatches_remaining):
_logger.debug(f"steady iter: {i} / {num_microbatches_remaining}")
last_iteration = i == (num_microbatches_remaining - 1)
cur_microbatch = get_kth_microbatch(batch, i + num_warmup_microbatches)
output_tensor = forward_step(forward_step_func, cur_microbatch, model, input_tensor, losses_reduced)
if forward_only:
_logger.debug("send fwd")
p2p_communication.send_forward(output_tensor, tensor_shape=tensor_shape)
if not last_iteration:
_logger.debug("receive fwd (last iteration)")
input_tensor = p2p_communication.recv_forward(tensor_shape=tensor_shape)
else:
_logger.debug("send fwd & receive bwd")
output_tensor_grad = p2p_communication.send_forward_recv_backward(output_tensor, tensor_shape=tensor_shape)
# Add input_tensor and output_tensor to end of list.
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
# Pop input_tensor and output_tensor from the start of the list for the backward pass.
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
input_tensor_grad = backward_step(
input_tensor, output_tensor, output_tensor_grad
)
if last_iteration:
input_tensor = None
_logger.debug("send bwd")
p2p_communication.send_backward(input_tensor_grad, tensor_shape=tensor_shape)
else:
_logger.debug("send bwd and receive fwd")
input_tensor = p2p_communication.send_backward_recv_forward(
input_tensor_grad, tensor_shape=tensor_shape)
###################################################################################################################
# Run cooldown backward passes.
###################################################################################################################
_logger.info("Cooldown phase")
if not forward_only:
for i in range(num_warmup_microbatches):
_logger.debug(f"cooldown iter: {i} / {num_warmup_microbatches}")
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
_logger.debug("receive bwd")
output_tensor_grad = p2p_communication.recv_backward(tensor_shape=tensor_shape)
input_tensor_grad = backward_step(
input_tensor, output_tensor, output_tensor_grad
)
_logger.debug("send bwd")
p2p_communication.send_backward(input_tensor_grad, tensor_shape=tensor_shape)
return losses_reduced
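# --- Hedged usage sketch (illustrative, not part of the original file) ---------------------
# `schedule_fn` stands for the pipelining schedule defined above; `tokens`, `labels`, and the
# loss name "lm_loss" are placeholders, and the details of how non-first stages consume the
# received activations are omitted.  The forward-step callable returns the model output plus
# a loss closure, and the closure returns (loss tensor, dict of str to tensor), matching the
# contract described in the docstring above.
def run_one_pipeline_iteration(schedule_fn, batch, model, tensor_shape):
    def forward_step_func(microbatch, model):
        tokens, labels = microbatch
        output_tensor = model(tokens)

        def loss_func(output_tensor):
            loss = torch.nn.functional.cross_entropy(
                output_tensor.view(-1, output_tensor.size(-1)), labels.view(-1)
            )
            return loss, {"lm_loss": loss.detach()}

        return output_tensor, loss_func

    # Only the last pipeline stage gets a non-empty list of reduced losses back.
    return schedule_fn(
        forward_step_func, batch, model, forward_only=False, tensor_shape=tensor_shape
    )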
# coding=utf-8
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for pipeline model parallel."""
from typing import Optional, List, Union
import torch
from torch.nn.parallel import DistributedDataParallel
from apex.multi_tensor_apply import multi_tensor_applier
from apex.transformer import parallel_state
from apex.transformer.microbatches import build_num_microbatches_calculator
from apex.transformer.pipeline_parallel._timers import _Timers
if multi_tensor_applier.available:
import amp_C
_GLOBAL_ARGS = None
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
_GLOBAL_TOKENIZER = None
_GLOBAL_TENSORBOARD_WRITER = None
_GLOBAL_AUTORESUME = None
_GLOBAL_TIMERS = None
Shape = Union[List[int], torch.Size]
def listify_model(model: Union[torch.nn.Module, List[torch.nn.Module]]) -> List[torch.nn.Module]:
if isinstance(model, list):
return model
return [model]
def _ensure_var_is_initialized(var, name):
"""Make sure the input variable is not None."""
assert var is not None, "{} is not initialized.".format(name)
def _ensure_var_is_not_initialized(var, name):
"""Make sure the input variable is not None."""
assert var is None, "{} is already initialized.".format(name)
def setup_microbatch_calculator(
rank: int,
rampup_batch_size: Optional[List[int]],
global_batch_size: int,
micro_batch_size: int,
data_parallel_size: int,
) -> None:
global _GLOBAL_NUM_MICROBATCHES_CALCULATOR
_ensure_var_is_not_initialized(_GLOBAL_NUM_MICROBATCHES_CALCULATOR, 'num microbatches calculator')
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = build_num_microbatches_calculator(
rank, rampup_batch_size, global_batch_size, micro_batch_size, data_parallel_size)
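# Hedged usage sketch: the calculator is set up once per process, typically right after the
# model-parallel groups are initialized.  The sizes below (global batch 512, micro batch 4,
# 8-way data parallelism) are illustrative only.
def _example_setup_microbatches(rank: int) -> None:
    setup_microbatch_calculator(
        rank=rank,
        rampup_batch_size=None,   # no batch-size ramp-up
        global_batch_size=512,
        micro_batch_size=4,
        data_parallel_size=8,
    )
    # 512 / (4 * 8) microbatches are run per pipeline step.
    assert get_num_microbatches() == 16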
def _reconfigure_microbatch_calculator(
rank: int,
rampup_batch_size: Optional[List[int]],
global_batch_size: int,
micro_batch_size: int,
data_parallel_size: int,
) -> None:
if torch.distributed.get_rank() == 0:
import warnings
warnings.warn("This function is intended only for unit tests.")
global _GLOBAL_NUM_MICROBATCHES_CALCULATOR
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = build_num_microbatches_calculator(
rank, rampup_batch_size, global_batch_size, micro_batch_size, data_parallel_size)
def get_micro_batch_size():
return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.micro_batch_size
def get_num_microbatches():
return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get()
def get_current_global_batch_size():
return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get_current_global_batch_size()
def update_num_microbatches(consumed_samples, consistency_check=True):
_GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples, consistency_check)
# note (mkozuki): Comment out in favor of `get_kth_microbatch`
def _split_batch_into_microbatch(
batch: List[torch.Tensor],
*,
_micro_batch_size: Optional[int] = None,
_global_batch_size: Optional[int] = None,
) -> List[List[torch.Tensor]]:
micro_batch_size = _micro_batch_size
global_batch_size = _global_batch_size
if micro_batch_size is None:
micro_batch_size = get_micro_batch_size()
if global_batch_size is None:
global_batch_size = get_current_global_batch_size()
for i in range(0, global_batch_size, micro_batch_size):
yield [x[i:i + micro_batch_size] for x in batch]  # `i` already advances by `micro_batch_size`
# TODO(mkozuki): Support non-tensor local minibatches?
def get_kth_microbatch(batch: List[torch.Tensor], k: int) -> List[torch.Tensor]:
"""Create a list of microbatches from a list of local minibatches.
This function creates a list of `k`th microbatches from a list of local minibatches.
`a local minibatch` consists of `global_batch_size / data_parallel_size` samples.
"""
micro_batch_size = get_micro_batch_size()
return [x[k * micro_batch_size:(k + 1) * micro_batch_size] for x in batch]
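# Hedged illustration of the slicing performed above, with a made-up local minibatch of
# 8 samples and `micro_batch_size == 2`: microbatch k == 1 selects samples 2 and 3 from
# every tensor in the list.
def _example_kth_microbatch() -> None:
    tokens = torch.arange(8).unsqueeze(-1)   # shape (8, 1): one tensor of the local minibatch
    labels = torch.arange(8)
    k, micro_batch_size = 1, 2
    micro = [x[k * micro_batch_size:(k + 1) * micro_batch_size] for x in (tokens, labels)]
    assert micro[0].squeeze(-1).tolist() == [2, 3] and micro[1].tolist() == [2, 3]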
def get_autoresume():
return _GLOBAL_AUTORESUME
def _set_timers():
"""Initialize timers."""
global _GLOBAL_TIMERS
_ensure_var_is_not_initialized(_GLOBAL_TIMERS, "timers")
_GLOBAL_TIMERS = _Timers()
def get_timers():
"""Return timers."""
_ensure_var_is_initialized(_GLOBAL_TIMERS, "timers")
return _GLOBAL_TIMERS
def print_rank_0(message: str) -> None:
"""If distributed is initialized, print only on rank 0."""
if torch.distributed.is_initialized():
if torch.distributed.get_rank() == 0:
print(message, flush=True)
else:
print(message, flush=True)
def is_last_rank():
return torch.distributed.get_rank() == (torch.distributed.get_world_size() - 1)
def print_rank_last(message):
"""If distributed is initialized, print only on last rank."""
if torch.distributed.is_initialized():
if is_last_rank():
print(message, flush=True)
else:
print(message, flush=True)
def param_is_not_shared(param: torch.nn.Parameter) -> bool:
return not getattr(param, "shared", False)  # True unless the parameter is marked as shared
def unwrap_model(model, module_instances=(DistributedDataParallel,)):
return_list = True
if not isinstance(model, list):
model = [model]
return_list = False
unwrapped_model = []
for model_module in model:
while isinstance(model_module, module_instances):
model_module = model_module.module
unwrapped_model.append(model_module)
if not return_list:
return unwrapped_model[0]
return unwrapped_model
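# Hedged usage sketch: peeling the DistributedDataParallel wrapper off before inspecting or
# checkpointing the underlying module.  `build_model` is a placeholder for whatever constructs
# the bare `nn.Module`; a real run also needs an initialized process group and a CUDA device.
def _example_unwrap(build_model) -> None:
    wrapped = DistributedDataParallel(build_model().cuda())
    bare = unwrap_model(wrapped)          # single module in, single module out
    assert not isinstance(bare, DistributedDataParallel)
    # Lists are unwrapped element-wise and returned as a list.
    assert isinstance(unwrap_model([wrapped]), list)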
def calc_params_l2_norm(model: torch.nn.Module, bf16: bool):
"""Calculate l2 norm of parameters """
# args = get_args()
if not isinstance(model, list):
model = [model]
# Remove duplicate params.
params_data = []
for model_ in model:
for param in model_.parameters():
is_not_shared = param_is_not_shared(param)
is_not_tp_duplicate = parallel_state.param_is_not_tensor_parallel_duplicate(param)
if is_not_shared and is_not_tp_duplicate:
if bf16:
params_data.append(param.data.float())
else:
params_data.append(param.data)
# Calculate norm
dummy_overflow_buf = torch.cuda.IntTensor([0])
norm, _ = multi_tensor_applier(
amp_C.multi_tensor_l2norm, dummy_overflow_buf, [params_data], False # no per-parameter norm
)
norm_2 = norm * norm
# Sum across all model-parallel GPUs.
torch.distributed.all_reduce(
norm_2, op=torch.distributed.ReduceOp.SUM, group=parallel_state.get_model_parallel_group()
)
return norm_2.item() ** 0.5
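# A hedged, plain-PyTorch reference for the reduction above (no amp_C fused kernel): each rank
# sums the squares of its non-duplicated parameters, the partial sums are added across the
# model-parallel group, and the square root of the total is the global L2 norm.
def _params_l2_norm_reference(params: List[torch.nn.Parameter]) -> float:
    local_sq = torch.zeros(1, device="cuda")
    for p in params:
        local_sq += (p.data.float() ** 2).sum()
    torch.distributed.all_reduce(
        local_sq, op=torch.distributed.ReduceOp.SUM, group=parallel_state.get_model_parallel_group()
    )
    return local_sq.item() ** 0.5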
def average_losses_across_data_parallel_group(losses):
"""Reduce a tensor of losses across all GPUs."""
averaged_losses = torch.cat([loss.clone().detach().view(1) for loss in losses])
torch.distributed.all_reduce(averaged_losses, group=parallel_state.get_data_parallel_group())
averaged_losses = averaged_losses / torch.distributed.get_world_size(
group=parallel_state.get_data_parallel_group()
)
return averaged_losses
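# Hedged illustration of the reduction above: every data-parallel rank contributes its own loss,
# the all-reduce sums the values, and the division turns the sum into the mean.  With 4
# data-parallel ranks reporting 1.0, 2.0, 3.0 and 4.0, every rank ends up with 2.5.
def _example_average_loss(local_loss: torch.Tensor) -> torch.Tensor:
    # `local_loss` is this rank's scalar loss; the result is identical on all ranks.
    return average_losses_across_data_parallel_group([local_loss])[0]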
def report_memory(name):
"""Simple GPU memory report."""
mega_bytes = 1024.0 * 1024.0
string = name + " memory (MB)"
string += " | allocated: {}".format(torch.cuda.memory_allocated() / mega_bytes)
string += " | max allocated: {}".format(torch.cuda.max_memory_allocated() / mega_bytes)
string += " | reserved: {}".format(torch.cuda.memory_reserved() / mega_bytes)
string += " | max reserved: {}".format(torch.cuda.max_memory_reserved() / mega_bytes)
if parallel_state.get_data_parallel_rank() == 0:
print("[Rank {}] {}".format(torch.distributed.get_rank(), string), flush=True)
def print_params_min_max_norm(optimizer, iteration):
"""Print min, max, and norm of all parameters."""
index = 0
rank = torch.distributed.get_rank()
string = "iteration, rank, index, tensor-model-parallel, min, max, norm\n"
optimizer_ = optimizer.optimizer
for param_group in optimizer_.param_groups:
for param in param_group["params"]:
index += 1
min_ = param.data.min()
max_ = param.data.max()
norm = torch.linalg.norm(param.data)
string += "{:7d}, {:4d}, {:4d}, {:2d}, ".format(
iteration, rank, index, int(param.tensor_model_parallel)
)
string += "{:.6E}, {:.6E}, {:.6E}\n".format(min_, max_, norm)
print(string, flush=True)
# NOTE (mkozuki): APEX doesn't have anything equivalent for
# `_GLOBAL_ADLR_AUTORESUME` like Megatron-LM.
# def check_adlr_autoresume_termination(iteration, model, optimizer, lr_scheduler, save: bool):
# """Check for autoresume signal and exit if it is received."""
# from apex.ppu.checkpointing import save_checkpoint
#
# autoresume = get_adlr_autoresume()
# # Add barrier to ensure consistency.
# torch.distributed.barrier()
# if autoresume.termination_requested():
# if save:
# save_checkpoint(iteration, model, optimizer, lr_scheduler)
# print_rank_0(">>> autoresume termination request found!")
# if torch.distributed.get_rank() == 0:
# autoresume.request_resume()
# print_rank_0(">>> training terminated. Returning")
# sys.exit(0)
def get_ltor_masks_and_position_ids(
data, eod_token, reset_position_ids, reset_attention_mask, eod_mask_loss
):
"""Build masks and position id for left to right model."""
# Extract batch size and sequence length.
micro_batch_size, seq_length = data.size()
# Attention mask (lower triangular).
if reset_attention_mask:
att_mask_batch = micro_batch_size
else:
att_mask_batch = 1
attention_mask = torch.tril(
torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)
).view(att_mask_batch, 1, seq_length, seq_length)
# Loss mask.
loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
if eod_mask_loss:
loss_mask[data == eod_token] = 0.0
# Position ids.
position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
position_ids = position_ids.unsqueeze(0).expand_as(data)
# We need to clone as the ids will be modified based on batch index.
if reset_position_ids:
position_ids = position_ids.clone()
if reset_position_ids or reset_attention_mask:
# Loop through the batches:
for b in range(micro_batch_size):
# Find indices where EOD tokens are located.
eod_index = position_ids[b, data[b] == eod_token]
# Detach indices from positions if going to modify positions.
if reset_position_ids:
eod_index = eod_index.clone()
# Loop through EOD indices:
prev_index = 0
for j in range(eod_index.size()[0]):
i = eod_index[j]
# Mask out attention across the EOD boundary.
if reset_attention_mask:
attention_mask[b, 0, (i + 1) :, : (i + 1)] = 0
# Reset positions.
if reset_position_ids:
position_ids[b, (i + 1) :] -= i + 1 - prev_index
prev_index = i + 1
# Convert attention mask to binary:
attention_mask = attention_mask < 0.5
return attention_mask, loss_mask, position_ids
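# Hedged toy example of the mask builder above: one sequence of length 6 whose EOD token id is
# 0 (arbitrary here), with position and attention resets plus EOD loss masking enabled.
def _example_ltor_masks() -> None:
    data = torch.tensor([[5, 7, 0, 9, 4, 0]])      # micro_batch_size=1, seq_length=6
    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        data, eod_token=0, reset_position_ids=True, reset_attention_mask=True, eod_mask_loss=True
    )
    # Positions restart after each EOD token.
    assert position_ids[0].tolist() == [0, 1, 2, 0, 1, 2]
    # Loss is masked (0.0) exactly at the EOD positions.
    assert loss_mask[0].tolist() == [1.0, 1.0, 0.0, 1.0, 1.0, 0.0]
    # `attention_mask` is boolean; True marks positions that must not be attended to.
    assert bool(attention_mask[0, 0, 3, 1])        # a token after the EOD cannot see tokens before it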
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,11 +15,11 @@
"""Model parallel utility interface."""
from .cross_entropy import vocab_parallel_cross_entropy
from apex.transformer.tensor_parallel.cross_entropy import vocab_parallel_cross_entropy
from .data import broadcast_data
from apex.transformer.tensor_parallel.data import broadcast_data
from .layers import (
from apex.transformer.tensor_parallel.layers import (
ColumnParallelLinear,
RowParallelLinear,
VocabParallelEmbedding,
@@ -28,7 +28,7 @@ from .layers import (
copy_tensor_model_parallel_attributes,
)
from .mappings import (
from apex.transformer.tensor_parallel.mappings import (
copy_to_tensor_model_parallel_region,
gather_from_tensor_model_parallel_region,
reduce_from_tensor_model_parallel_region,
@@ -41,11 +41,9 @@ from .random import (
init_checkpointed_activations_memory_buffer,
model_parallel_cuda_manual_seed,
reset_checkpointed_activations_memory_buffer,
gather_split_1d_tensor,
split_tensor_into_1d_equal_chunks,
)
from .utils import divide, split_tensor_along_last_dim
from apex.transformer.tensor_parallel.utils import split_tensor_along_last_dim
__all__ = [
@@ -71,9 +69,6 @@ __all__ = [
"init_checkpointed_activations_memory_buffer",
"model_parallel_cuda_manual_seed",
"reset_checkpointed_activations_memory_buffer",
"gather_split_1d_tensor",
"split_tensor_into_1d_equal_chunks",
# utils.py
"divide",
"split_tensor_along_last_dim",
]
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,10 +14,10 @@
# limitations under the License.
import torch
from ..parallel_state import get_tensor_model_parallel_group
from ..parallel_state import get_tensor_model_parallel_rank
from ..parallel_state import get_tensor_model_parallel_world_size
from .utils import VocabUtility
from apex.transformer.parallel_state import get_tensor_model_parallel_group
from apex.transformer.parallel_state import get_tensor_model_parallel_rank
from apex.transformer.parallel_state import get_tensor_model_parallel_world_size
from apex.transformer.tensor_parallel.utils import VocabUtility
class _VocabParallelCrossEntropy(torch.autograd.Function):
@@ -30,7 +30,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
logits_max, op=torch.distributed.ReduceOp.MAX, group=get_tensor_model_parallel_group()
)
# Subtract the maximum value.
vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
vocab_parallel_logits = vocab_parallel_logits - logits_max.unsqueeze(dim=-1)
# Get the partition's vocab indices
get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
@@ -100,4 +100,4 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
def vocab_parallel_cross_entropy(vocab_parallel_logits, target):
"""Helper function for the cross entropy."""
return _VocabParallelCrossEntropy.apply(torch.clone(vocab_parallel_logits), target)
return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target)
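# Hedged note on the two hunks above: subtracting the (all-reduced) row-wise maximum before
# exponentiating is the standard log-sum-exp stabilization, and doing it out of place means the
# caller's logits are no longer mutated, which is why the helper can drop `torch.clone`.
# A single-device sketch of the same stabilization (names here are illustrative):
def _stable_exp_sum(logits: torch.Tensor) -> torch.Tensor:
    logits_max = logits.max(dim=-1, keepdim=True).values
    shifted = logits - logits_max          # out-of-place: `logits` itself is untouched
    return shifted.exp().sum(dim=-1)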
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,9 +14,9 @@
# limitations under the License.
import torch
from ..parallel_state import get_tensor_model_parallel_group
from ..parallel_state import get_tensor_model_parallel_rank
from ..parallel_state import get_tensor_model_parallel_src_rank
from apex.transformer.parallel_state import get_tensor_model_parallel_group
from apex.transformer.parallel_state import get_tensor_model_parallel_rank
from apex.transformer.parallel_state import get_tensor_model_parallel_src_rank
_MAX_DATA_DIM = 5