Unverified Commit 3ff1a10f authored by Masaki Kozuki, committed by GitHub

[transformer] Port Sequence Parallelism (takeover of #1396) (#1400)

* it looks possible to remove this file

* add communication collectives

* update Column|RowParallelLinear

* update checkpoint function

* update function name

* parity between public and private collectives

* row parallel linear

* column parallel linear

* sequence parallel: p2p comm

fix typo

* sequence parallel: pipeline parallel

* fix typo

* add layernorm with sequence_parallel_enabled attr

* class variable -> member variable

* fix col parallel test with sequence parallel

* Initial test of `forward_backward_pipelining_without_interleaving` with `model_type=ModelType.encoder_and_decoder`

* add cases pretending to test sequence_parallel

* Apply 2 suggestion(s) to 1 file(s)

* update sequence_parallel_enabled docstring

* update docstring: order of tensor dimensions, sequence_parallel_enabled behavior

* Divide sequence_length if sequence parallel

The tensor shape should be updated if sequence parallelism is enabled.
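
For illustration only (not part of this diff), the shape update amounts to dividing the sequence dimension by the tensor model parallel world size; the sizes below are hypothetical:

def adjust_tensor_shape(tensor_shape, tensor_model_parallel_world_size):
    # Each tensor model parallel rank only holds 1/world_size of the sequence dimension.
    seq_length, micro_batch_size, hidden_size = tensor_shape
    return (seq_length // tensor_model_parallel_world_size, micro_batch_size, hidden_size)

assert adjust_tensor_shape((2048, 4, 1024), 2) == (1024, 4, 1024)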

* cherry-pick https://github.com/NVIDIA/Megatron-LM/commit/8474e6e54fcb9dfa37aea039352f9fb485fb6f61

* type annotation

* Fix matmul call in RowParallelLinear

Fix `sequence_parallel_enabled` to `False` as you can see in
https://github.com/NVIDIA/Megatron-LM/blob/d898a8991d1a08d29074f87819d1bf41517e35f5/megatron/mpu/layers.py#L511-L514

* update rowparallellinear test

* fix `loss_weight` is not defined in test_layers

* @eqy's comment

* mixed fused layer norm

* fix typo

* misc

* test_layers cleanup

* Skip Bert/GPT script

These two models haven't been updated for sequence parallel yet, e.g. the change of dimension order from (batch, sequence, feature) to (sequence, batch, feature) and the global argument variables.
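
For illustration only (not part of this diff): moving activations from the old (batch, sequence, feature) layout to the (sequence, batch, feature) layout is a single transpose, assuming contiguous 3D activations:

import torch

x_bsf = torch.randn(2, 8, 4)                 # (batch=2, sequence=8, feature=4)
x_sbf = x_bsf.transpose(0, 1).contiguous()   # (sequence, batch, feature)
assert x_sbf.shape == (8, 2, 4)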

* debug part 1/N: comment out `x.retain_grad`

* debug part 2/N: [ColumnParallelLinear] comment out overriding of sequence_parallel_enabled

* debug 3/N: add pipeline test with parallel mlp

* Fix handling `self.input_tensor` and argument

* tp2pp4 ModelType.encoder_or_decoder is failing, which may be my fault because the backward pass complains that the output and grad_output shapes don't match

* revert debug 1/N

* defer tensor model parallel size > 1

* split tensor in sequence dim

* cosmetic

* cosmetic: remove archaic comment

* enable TP>1 for encoder_and_decoder as well

* set requires_grad=True always...

* Set `scatter_gather_tensors_in_pipeline` to :obj:`False`

so that NeMo Megatron's GPT works with sequence parallel enabled.

* brush up comment of `requires_grad()`

According to @ptrblck, PyTorch DistributedDataParallel can hang
when some tensor (or parameter) doesn't require grad.
In my understanding, this forced `requires_grad` is a different issue from that.

* misc changes of scatter_gather_tensors_in_pipeline comment

* guard for torch_ucc

* cosmetic changes related to tests

* update command line arguments

* update TransformerLanguageModel

* rename

* move gpt to gpt.py

* update bert

* add all_gather for params in sequence parallel region

* misc. some diffs were lost during rebasing...

* updates for non sequence parallel execution

* gpt with sequence parallel

* Apply 2 suggestion(s) to 2 file(s)

* update tensor&pipeline parallel size

* why is `sequence_parallel_enabled` not supplied!? Did I mess up when rebasing?

* cosmetic fix

* correct key is sequence_parallel_enabled
parent 57f890a7
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
from apex.transformer.layers.layer_norm import FastLayerNorm
from apex.transformer.layers.layer_norm import FusedLayerNorm
from apex.transformer.layers.layer_norm import MixedFusedLayerNorm
__all__ = [
"FastLayerNorm",
"FusedLayerNorm",
"MixedFusedLayerNorm",
]
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# NOTE(mkozuki): This file defines LayerNorm variants that are compatible with Megatron-LM
# while avoiding introducing the breaking change of the `"sequence_parallel_enabled"` attribute into apex.normalization.FusedLayerNorm
# and apex.contrib.layer_norm.FastLayerNorm.
import warnings
import torch
from apex.normalization import FusedLayerNorm as OrigFusedLayerNorm
from apex.normalization import MixedFusedLayerNorm as OrigMixedFusedLayerNorm
try:
from apex.contrib.layer_norm import FastLayerNorm as OrigFastLayerNorm
except ImportError:
HAS_FAST_LAYER_NORM = False
else:
HAS_FAST_LAYER_NORM = True
__all__ = [
"FusedLayerNorm",
"FastLayerNorm",
"MixedFusedLayerNorm",
]
def _set_sequence_parallel_enabled(
param: torch.Tensor,
sequence_parallel_enabled: bool,
) -> None:
setattr(param, "sequence_parallel_enabled", sequence_parallel_enabled)
class FusedLayerNorm(OrigFusedLayerNorm):
def __init__(
self,
normalized_shape,
eps: float = 1e-5,
elementwise_affine: bool = True,
*,
sequence_parallel_enabled: bool = False,
):
super().__init__(
normalized_shape=normalized_shape,
eps=eps,
elementwise_affine=elementwise_affine,
)
self.sequence_parallel_enabled = sequence_parallel_enabled
if self.elementwise_affine:
_set_sequence_parallel_enabled(self.weight, self.sequence_parallel_enabled)
_set_sequence_parallel_enabled(self.bias, self.sequence_parallel_enabled)
# note: MixedFusedLayerNorm is no different from FusedLayerNorm if it's used in `torch.cuda.amp`.
class MixedFusedLayerNorm(OrigMixedFusedLayerNorm):
def __init__(
self,
normalized_shape,
eps: float = 1e-5,
**kwargs,
) -> None:
self.sequence_parallel_enabled = kwargs.get("sequence_parallel_enabled", False)
super().__init__(normalized_shape=normalized_shape, eps=eps, **kwargs)
if self.sequence_parallel_enabled:
_set_sequence_parallel_enabled(self.weight, self.sequence_parallel_enabled)
_set_sequence_parallel_enabled(self.bias, self.sequence_parallel_enabled)
if HAS_FAST_LAYER_NORM:
class FastLayerNorm(OrigFastLayerNorm):
def __init__(
self,
hidden_size,
eps: float = 1e-5,
*,
sequence_parallel_enabled: bool = False,
):
super().__init__(
hidden_size=hidden_size,
eps=eps
)
self.sequence_parallel_enabled = sequence_parallel_enabled
_set_sequence_parallel_enabled(self.weight, self.sequence_parallel_enabled)
_set_sequence_parallel_enabled(self.bias, self.sequence_parallel_enabled)
else:
class FastLayerNorm(FusedLayerNorm):
def __init__(
self,
hidden_size,
eps: float = 1e-5,
*,
sequence_parallel_enabled: bool = False,
):
warnings.warn("`apex.contrib.layer_norm.FastLayerNorm` isn't available thus falling back to `apex.normalization.FusedLayerNorm`")
super().__init__(
normalized_shape=hidden_size,
eps=eps,
elementwise_affine=True,
sequence_parallel_enabled=sequence_parallel_enabled,
)
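
A hedged usage sketch of the wrappers above (not part of this diff), assuming APEX is installed with its C++/CUDA extensions: the wrappers only tag the affine parameters so that downstream sequence parallel code (e.g. the parameter handling mentioned in the commit messages above) can identify them via the `sequence_parallel_enabled` attribute.

from apex.transformer.layers import FusedLayerNorm

ln = FusedLayerNorm(1024, sequence_parallel_enabled=True)
# Both affine parameters now carry the attribute that marks them as living in a
# sequence parallel region.
assert ln.weight.sequence_parallel_enabled and ln.bias.sequence_parallel_enabled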
# coding=utf-8
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021-22, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO(mkozuki): Consider removing `timers`.
from functools import reduce
import operator
......@@ -20,12 +21,15 @@ from typing import Union, Optional, Tuple
import torch
from apex.transformer import parallel_state
from apex.transformer.log_util import get_transformer_logger
from apex.transformer.utils import split_tensor_into_1d_equal_chunks
from apex.transformer.utils import gather_split_1d_tensor
from apex.transformer.pipeline_parallel.utils import Shape
from apex.transformer.pipeline_parallel._timers import _Timers
_logger = get_transformer_logger(__name__)
class FutureTensor:
def __init__(self, tensor: torch.Tensor, waitfunc):
......@@ -42,11 +46,11 @@ class FutureTensor:
def _run_p2pops(
tensor_send_prev: Union[torch.Tensor, None],
tensor_send_next: Union[torch.Tensor, None],
tensor_recv_prev: Union[torch.Tensor, None],
tensor_recv_next: Union[torch.Tensor, None],
async_comm: bool = False
tensor_send_prev: Union[torch.Tensor, None],
tensor_send_next: Union[torch.Tensor, None],
tensor_recv_prev: Union[torch.Tensor, None],
tensor_recv_next: Union[torch.Tensor, None],
async_comm: bool = False
):
ops = []
if tensor_send_prev is not None:
......@@ -93,6 +97,11 @@ def _run_p2pops(
return (None, None, None, None)
# TODO(mkozuki): Check if it's possible to sunset `override_scatter_gather_tensors_in_pipeline`.
# TODO(mkozuki): Think about whether it's possible to push some logic and arguments, e.g.
# `scatter_gather_tensors_in_pipeline`, `sequence_parallel_enabled`, and
# `override_scatter_gather_tensors_in_pipeline`, to the user of
# apex.transformer forward_backward functions.
def _communicate(
tensor_send_next: Optional[torch.Tensor],
tensor_send_prev: Optional[torch.Tensor],
......@@ -106,9 +115,14 @@ def _communicate(
params_dtype: Optional[torch.dtype] = None,
fp32_residual_connection: bool = False,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
) -> Tuple[Union[torch.Tensor, FutureTensor, None], Union[torch.Tensor, FutureTensor, None]]:
"""Base function for communication of tensors between stages.
.. note::
Reference https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/blob/cfd2e2160700b7f2c1bf35298ac14bc341f4c759/megatron/p2p_communication.py#L24-L159
dtype logic: If none of ``dtype_``, ``params_dtype``, ``fp32_residual_connection`` is specified,
torch.float32 is used.
......@@ -130,6 +144,9 @@ def _communicate(
params_dtype: Optional and legacy. Defaults to torch.float. If you manually call `.half()` or `.bfloat16()` on
your model deliberately, pass this argument.
fp32_residual_connection: Optional. If :obj:`True`, move residual connections to fp32.
sequence_parallel_enabled: Set to :obj:`True` if sequence parallelism is enabled.
This argument is here for consistency with Megatron-LM.
It affects the communication optimization, not the tensor_shape update.
Returns:
tuple containing
......@@ -137,6 +154,13 @@ def _communicate(
- tensor_recv_prev: `torch.Tensor` if `recv_prev` is :obj:`True`, `None` otherwise.
- tensor_recv_next: `torch.Tensor` if `recv_next` is :obj:`True`, `None` otherwise.
"""
if async_comm and sequence_parallel_enabled:
import warnings # NOQA
class ExperimentalWarning(UserWarning): pass # NOQA
warnings.warn(
"The combination of `async_comm` and `sequence_parallel_enabled` is not well tested.",
ExperimentalWarning,
)
# Create placeholder tensors for receive in forward and backward directions if needed.
tensor_recv_prev = None
tensor_recv_next = None
......@@ -144,25 +168,45 @@ def _communicate(
# In megatron, `tensor_shape` is set to `(args.seq_length, args.micro_batch_size, args.hidden_size)`
raise RuntimeError(
"`tensor_shape` must be specified. Common `tensor_shape` is `(seq_length, micro_batch_size, hidden_size)`")
if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
tensor_chunk_shape = (reduce(operator.mul, tensor_shape, 1) // parallel_state.get_tensor_model_parallel_world_size(),)
tensor_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
override_scatter_gather_tensors_in_pipeline_ = False
# TODO(mkozuki): Demystify hardcode False of `scatter_gather_tensors_in_pipeline` and add a testcase if possible.
# NOTE(mkozuki): This is super strange and doesn't make sense to me. I have no idea what is happening here.
# However, I can say that this hardcoding override is necessary for sequence parallel in nemo megatron to work.
# I've not managed to reproduce the hang using standalone GPT with sequence parallel.
# The hang in NeMo Megatron happens in the 3rd iteration, the last iteration of the steady phase inside
# forward_backward_pipelining_without_interleaving, pipeline parallel rank of 0 (tensor model parallel world
# size of 2 and pipeline model parallel world size of 2). The APEX and NeMo commits at the time were
# https://github.com/NVIDIA/apex/pull/1396/commits/3060c98dd8ba42abf7702ea9d2cff0f39ea74f45 and
# https://github.com/NVIDIA/NeMo/pull/4232/commits/1cb32dfca2ab9b20f53ebdb84476c34cb42f0205.
# The PyTorch version was 1.13.0a0+git2d354cd, for what it's worth.
# Currently, this is indiscriminately set to `False`, which can lead to an unexpected performance regression
# for the non sequence parallel case.
scatter_gather_tensors_in_pipeline = False
if scatter_gather_tensors_in_pipeline and not sequence_parallel_enabled:
tensor_chunk_size = int(reduce(operator.mul, tensor_shape, 1))
if tensor_chunk_size % tensor_parallel_size == 0:
tensor_chunk_shape = [tensor_chunk_size // tensor_parallel_size]
else:
tensor_chunk_shape = tensor_shape
override_scatter_gather_tensors_in_pipeline_ = True
else:
tensor_chunk_shape = tensor_shape
# The dtype logic below is copied from NVIDIA/Megatron-LM repo:
# https://github.com/NVIDIA/Megatron-LM/blob/d41696840ed0a7edb7e0499eb82a48ae112d9bb3/megatron/p2p_communication.py#L74-L81
# NOTE (mkozuki): Currently NeMo implements APEX AMP O2 style using PyTorch. In O2 style, forcing p2p comm to
# use FP32 would be a perf killer, so I decided to bring back the `dtype_` argument with the default value of `None`.
# NOTE (mkozuki): In PyTorch AMP, i.e. `torch.cuda.amp.autocast` context, activation tensors can be either FP32,
# FP16, or BF16 and there's no way to tell the dtypes of tensors on different devices in general.
# It might be possible if we restrict model architecture.
dtype = params_dtype or torch.float
if fp32_residual_connection:
dtype = torch.float
requires_grad = True
if dtype_ is not None:
dtype = dtype_
requires_grad = False
# TODO(mkozuki): Figure out why this logic of requires_grad isn't working
# when sequence_parallel_enabled=True. Otherwise, `x.retain_grad()` of
# https://github.com/crcrpar/apex/blob/069832078a652b4bd8a99db84faf953a81415ab3/apex/transformer/pipeline_parallel/schedules/common.py#L360
# fails.
# requires_grad = False
if recv_prev:
tensor_recv_prev = torch.empty(
......@@ -180,7 +224,12 @@ def _communicate(
)
# Split tensor into smaller chunks if using scatter-gather optimization.
if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
scatter_gather_optimization_doable = (
not override_scatter_gather_tensors_in_pipeline_
and scatter_gather_tensors_in_pipeline
and not sequence_parallel_enabled
)
if scatter_gather_optimization_doable:
if tensor_send_next is not None:
tensor_send_next = split_tensor_into_1d_equal_chunks(tensor_send_next)
......@@ -210,7 +259,7 @@ def _communicate(
torch.cuda.synchronize()
# If using scatter-gather optimization, gather smaller chunks.
if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
if scatter_gather_optimization_doable:
if not async_comm:
if recv_prev:
tensor_recv_prev = (
......@@ -218,7 +267,7 @@ def _communicate(
.view(tensor_shape)
.requires_grad_()
)
if recv_next:
tensor_recv_next = (
gather_split_1d_tensor(tensor_recv_next)
......@@ -254,17 +303,17 @@ def _communicate(
if tensor_recv_next is not None:
future_tensor_recv_next = FutureTensor(tensor_recv_next, tensor_recv_next_waitfunc)
return future_tensor_recv_prev, future_tensor_recv_next
return tensor_recv_prev, tensor_recv_next
def recv_forward(
tensor_shape: Shape,
override_scatter_gather_tensors_in_pipeline: bool = False,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
async_comm: bool = False,
tensor_shape: Shape,
override_scatter_gather_tensors_in_pipeline: bool = False,
*,
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
timers: _Timers = None,
) -> Union[torch.Tensor, FutureTensor, None]:
"""Receive tensor from previous rank in pipeline (forward receive)."""
if parallel_state.is_pipeline_first_stage():
......@@ -280,6 +329,7 @@ def recv_forward(
override_scatter_gather_tensors_in_pipeline=override_scatter_gather_tensors_in_pipeline,
dtype_=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
# if timers is not None:
# timers("forward-recv").stop()
......@@ -287,11 +337,12 @@ def recv_forward(
def recv_backward(
tensor_shape: Shape = None,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
async_comm: bool = False,
tensor_shape: Shape = None,
*,
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
timers: _Timers = None,
) -> Union[torch.Tensor, FutureTensor, None]:
"""Receive tensor from next rank in pipeline (backward receive)."""
if parallel_state.is_pipeline_last_stage():
......@@ -306,6 +357,7 @@ def recv_backward(
tensor_shape=tensor_shape,
dtype_=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
# if timers is not None:
# timers("backward-recv").stop()
......@@ -313,13 +365,14 @@ def recv_backward(
def send_forward(
output_tensor: torch.Tensor,
override_scatter_gather_tensors_in_pipeline: bool = False,
tensor_shape: Shape = None,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
async_comm: bool = False,
output_tensor: torch.Tensor,
override_scatter_gather_tensors_in_pipeline: bool = False,
tensor_shape: Shape = None,
*,
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
timers: _Timers = None,
) -> None:
"""Send tensor to next rank in pipeline (forward send)."""
if parallel_state.is_pipeline_last_stage():
......@@ -335,19 +388,20 @@ def send_forward(
tensor_shape=tensor_shape,
dtype_=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
# if timers is not None:
# timers("forward-send").stop()
def send_backward(
input_tensor_grad: torch.Tensor,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
async_comm: bool = False,
input_tensor_grad: torch.Tensor,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
timers: _Timers = None,
) -> None:
"""Send tensor to previous rank in pipeline (backward send)."""
if parallel_state.is_pipeline_first_stage():
......@@ -362,18 +416,20 @@ def send_backward(
tensor_shape=tensor_shape,
dtype_=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
# if timers is not None:
# timers("backward-send").stop()
def send_forward_recv_backward(
output_tensor: torch.Tensor,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
async_comm: bool = False,
output_tensor: torch.Tensor,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
timers: _Timers = None,
) -> Union[torch.Tensor, FutureTensor, None]:
"""Batched send and recv with next rank in pipeline."""
if parallel_state.is_pipeline_last_stage():
......@@ -388,6 +444,7 @@ def send_forward_recv_backward(
tensor_shape=tensor_shape,
dtype_=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
# if timers is not None:
# timers("forward-send-backward-recv").stop()
......@@ -395,12 +452,13 @@ def send_forward_recv_backward(
def send_backward_recv_forward(
input_tensor_grad: torch.Tensor,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
async_comm: bool = False,
input_tensor_grad: torch.Tensor,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
timers: _Timers = None,
) -> Union[torch.Tensor, FutureTensor, None]:
"""Batched send and recv with previous rank in pipeline."""
if parallel_state.is_pipeline_first_stage():
......@@ -415,6 +473,7 @@ def send_backward_recv_forward(
tensor_shape=tensor_shape,
dtype_=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
# if timers is not None:
# timers("backward-send-forward-recv").stop()
......@@ -422,13 +481,14 @@ def send_backward_recv_forward(
def send_forward_recv_forward(
output_tensor: torch.Tensor,
recv_prev: bool,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
async_comm: bool = False,
output_tensor: torch.Tensor,
recv_prev: bool,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
timers: _Timers = None,
) -> Union[torch.Tensor, FutureTensor]:
"""Batched recv from previous rank and send to next rank in pipeline."""
# if timers is not None:
......@@ -441,6 +501,7 @@ def send_forward_recv_forward(
tensor_shape=tensor_shape,
dtype_=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
# if timers is not None:
# timers("forward-send-forward-recv").stop()
......@@ -448,13 +509,14 @@ def send_forward_recv_forward(
def send_backward_recv_backward(
input_tensor_grad: torch.Tensor,
recv_next: bool,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
async_comm: bool = False,
input_tensor_grad: torch.Tensor,
recv_next: bool,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
timers: _Timers = None,
) -> Union[torch.Tensor, FutureTensor]:
"""Batched recv from next rank and send to previous rank in pipeline."""
# if timers is not None:
......@@ -467,6 +529,7 @@ def send_backward_recv_backward(
tensor_shape=tensor_shape,
dtype_=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
# if timers is not None:
# timers("backward-send-backward-recv").stop()
......@@ -474,15 +537,16 @@ def send_backward_recv_backward(
def send_forward_backward_recv_forward_backward(
output_tensor: torch.Tensor,
input_tensor_grad: torch.Tensor,
recv_prev: bool,
recv_next: bool,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
async_comm: bool = False,
output_tensor: torch.Tensor,
input_tensor_grad: torch.Tensor,
recv_prev: bool,
recv_next: bool,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
timers: _Timers = None,
) -> Tuple[Union[torch.Tensor, FutureTensor], Union[torch.Tensor, FutureTensor]]:
"""Batched send and recv with previous and next ranks in pipeline."""
# if timers is not None:
......@@ -495,6 +559,7 @@ def send_forward_backward_recv_forward_backward(
tensor_shape=tensor_shape,
dtype_=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
# if timers is not None:
# timers("forward-backward-send-forward-backward-recv").stop()
......
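
A minimal sketch of the chunking decision in `_communicate` above (not the PR's code; sizes are hypothetical): the scatter-gather optimization flattens the tensor and sends 1/tp_world_size of it per rank, and it is skipped when sequence parallelism already splits the sequence dimension (note the PR additionally hardcodes `scatter_gather_tensors_in_pipeline = False`, as discussed in the comment above).

from functools import reduce
import operator

def chunk_shape(tensor_shape, tp_world_size, scatter_gather, sequence_parallel_enabled):
    # Flatten and divide only when the optimization applies and the element count divides evenly.
    if scatter_gather and not sequence_parallel_enabled:
        numel = reduce(operator.mul, tensor_shape, 1)
        if numel % tp_world_size == 0:
            return (numel // tp_world_size,)
    return tuple(tensor_shape)

assert chunk_shape((2048, 4, 1024), 2, True, False) == (2048 * 4 * 1024 // 2,)
assert chunk_shape((1024, 4, 1024), 2, True, True) == (1024, 4, 1024)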
......@@ -34,6 +34,7 @@ def _forward_backward_pipelining_with_interleaving(
grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
disable_autocast: bool = False,
deallocate_pipeline_outputs: bool = False,
sequence_parallel_enabled: bool = False,
**kwargs,
) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
"""Run interleaved 1F1B schedule with communication between pipeline stages as needed.
......@@ -57,13 +58,17 @@ def _forward_backward_pipelining_with_interleaving(
Keyword args:
forward_only:
tensor_shape: Shape of tensor.
tensor_shape: Shape of tensor. The tensor is expected to be 3D and its dimension order
is supposed to be ``(sequence, batch, hidden)``.
dtype: dtype used in p2p communication. If ``None`` (default value),
torch.float32 will be used even if ``autocast`` is enabled.
grad_scaler:
disable_autocast:
deallocate_pipeline_outputs: If :obj:`True`, free the data of the output tensor of
each pipeline stage. Experimental.
sequence_parallel_enabled: Set to :obj:`True` for this function to handle sequence length.
When :obj:`True`, the sequence length on each tensor model parallel rank is updated
to :math:`original\_sequence\_length / tensor\_model\_parallel\_world\_size`.
Returns:
a list of loss `torch.Tensor`s if the last stage, empty list otherwise.
......@@ -77,6 +82,15 @@ def _forward_backward_pipelining_with_interleaving(
"This option is not recommended."
)
# mypy will complain about the following if statement
if sequence_parallel_enabled:
seq_length, batch_size, hidden = tensor_shape
tensor_shape = (
seq_length // parallel_state.get_tensor_model_parallel_world_size(),
batch_size,
hidden,
)
num_model_chunks: int = len(model)
input_tensors: List[List[Union[None, torch.Tensor]]] = [
[] for _ in range(num_model_chunks)
......@@ -201,7 +215,11 @@ def _forward_backward_pipelining_with_interleaving(
###################################################################################################################
parallel_state.set_virtual_pipeline_model_parallel_rank(0)
input_tensors[0].append(
p2p_communication.recv_forward(tensor_shape=tensor_shape, dtype=dtype)
p2p_communication.recv_forward(
tensor_shape=tensor_shape,
dtype=dtype,
sequence_parallel_enabled=sequence_parallel_enabled,
)
)
_logger.info("Warmup phase")
for k in range(num_warmup_microbatches):
......@@ -247,6 +265,7 @@ def _forward_backward_pipelining_with_interleaving(
recv_next=recv_next,
tensor_shape=tensor_shape,
dtype=dtype,
sequence_parallel_enabled=sequence_parallel_enabled,
)
output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
else:
......@@ -256,9 +275,10 @@ def _forward_backward_pipelining_with_interleaving(
recv_prev=recv_prev,
tensor_shape=tensor_shape,
dtype=dtype,
sequence_parallel_enabled=sequence_parallel_enabled,
)
free_output_tensor(output_tensor, deallocate_pipeline_outputs)
input_tensors[next_forward_model_chunk_id].append(input_tensor)
free_output_tensor(output_tensor, deallocate_pipeline_outputs)
###################################################################################################################
# Run 1F1B in steady state.
......@@ -339,6 +359,7 @@ def _forward_backward_pipelining_with_interleaving(
recv_next=recv_next,
tensor_shape=tensor_shape,
dtype=dtype,
sequence_parallel_enabled=sequence_parallel_enabled,
)
free_output_tensor(output_tensor, deallocate_pipeline_outputs)
......@@ -356,7 +377,11 @@ def _forward_backward_pipelining_with_interleaving(
if not forward_only:
if all_warmup_microbatches:
output_tensor_grads[num_model_chunks - 1].append(
p2p_communication.recv_backward(tensor_shape=tensor_shape, dtype=dtype)
p2p_communication.recv_backward(
tensor_shape=tensor_shape,
dtype=dtype,
sequence_parallel_enabled=sequence_parallel_enabled,
)
)
for k in range(num_microbatches_remaining, num_microbatches):
_logger.debug(
......@@ -376,6 +401,7 @@ def _forward_backward_pipelining_with_interleaving(
recv_next=recv_next,
tensor_shape=tensor_shape,
dtype=dtype,
sequence_parallel_enabled=sequence_parallel_enabled,
)
)
......
......@@ -31,7 +31,19 @@ def get_tensor_shapes(
*,
tensor_shape: Union[List[int], torch.Size],
decoder_sequence_length: Optional[int] = None,
sequence_parallel_enabled: bool = False,
) -> Sequence[Sequence[int]]:
"""Get tensors shapes
Args:
rank: pipeline parallel rank
model_type:
Keyword Args:
tensor_shape:
decoder_sequence_length:
sequence_parallel_enabled:
"""
# Determine right tensor sizes (based on position of rank with respect to split
# rank) and model size.
# Send two tensors if model is T5 and rank is in decoder stage:
......@@ -44,21 +56,27 @@ def get_tensor_shapes(
len(tensor_shape) == 3
), f"`tensor_shape` should be [sequence_length, micro_batch_size, hidden_size] but {tensor_shape}"
sequence_length, micro_batch_size, hidden_size = tensor_shape
seq_len = sequence_length
if sequence_parallel_enabled:
seq_len = sequence_length // parallel_state.get_tensor_model_parallel_world_size()
tensor_shapes = []
if model_type == ModelType.encoder_and_decoder:
if decoder_sequence_length is None:
raise ValueError("`decoder_sequence_length` is required for `ModelType.encoder_and_decoder`")
dec_seq_len = decoder_sequence_length
if sequence_parallel_enabled:
dec_seq_len = decoder_sequence_length // parallel_state.get_tensor_model_parallel_world_size()
if parallel_state.is_pipeline_stage_before_split(rank):
# If next rank is after split, then need transpose for encoder_hidden_state.
if parallel_state.is_pipeline_stage_before_split(rank + 1):
tensor_shapes.append((sequence_length, micro_batch_size, hidden_size))
tensor_shapes.append((seq_len, micro_batch_size, hidden_size))
else:
tensor_shapes.append((micro_batch_size, sequence_length, hidden_size))
tensor_shapes.append((dec_seq_len, micro_batch_size, hidden_size))
else:
tensor_shapes.append((decoder_sequence_length, micro_batch_size, hidden_size))
tensor_shapes.append((micro_batch_size, sequence_length, hidden_size))
tensor_shapes.append((dec_seq_len, micro_batch_size, hidden_size))
tensor_shapes.append((seq_len, micro_batch_size, hidden_size))
else:
tensor_shapes.append((sequence_length, micro_batch_size, hidden_size))
tensor_shapes.append((seq_len, micro_batch_size, hidden_size))
return tensor_shapes
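
A hedged worked example of the shapes above for `ModelType.encoder_and_decoder` (hypothetical sizes, not from this diff): with sequence parallelism both the encoder and decoder sequence lengths are divided by the tensor model parallel world size before they enter the p2p tensor shapes.

tp_world_size = 2
sequence_length, decoder_sequence_length = 2048, 512
micro_batch_size, hidden_size = 4, 1024

seq_len = sequence_length // tp_world_size              # 1024
dec_seq_len = decoder_sequence_length // tp_world_size  # 256
assert (seq_len, micro_batch_size, hidden_size) == (1024, 4, 1024)
assert (dec_seq_len, micro_batch_size, hidden_size) == (256, 4, 1024)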
......@@ -67,13 +85,21 @@ def recv_forward(
*,
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
) -> List[Union[None, torch.Tensor, FutureTensor]]:
input_tensors = []
for tensor_shape in tensor_shapes:
if tensor_shape is None:
input_tensors.append(None)
else:
input_tensors.append(p2p_communication.recv_forward(tensor_shape=tensor_shape, dtype=dtype, async_comm=async_comm))
input_tensors.append(
p2p_communication.recv_forward(
tensor_shape=tensor_shape,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
)
return input_tensors
......@@ -82,13 +108,21 @@ def recv_backward(
*,
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
) -> List[Union[None, torch.Tensor, FutureTensor]]:
output_tensor_grads = []
for tensor_shape in tensor_shapes:
if tensor_shape is None:
output_tensor_grads.append(None)
else:
output_tensor_grads.append(p2p_communication.recv_backward(tensor_shape=tensor_shape, dtype=dtype, async_comm=async_comm))
output_tensor_grads.append(
p2p_communication.recv_backward(
tensor_shape=tensor_shape,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
)
return output_tensor_grads
......@@ -98,13 +132,20 @@ def send_forward(
*,
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
) -> None:
if not isinstance(output_tensors, list):
output_tensors = [output_tensors]
for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes):
if tensor_shape is None:
continue
p2p_communication.send_forward(output_tensor, tensor_shape=tensor_shape, dtype=dtype, async_comm=async_comm)
p2p_communication.send_forward(
output_tensor,
tensor_shape=tensor_shape,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
def send_backward(
......@@ -113,13 +154,20 @@ def send_backward(
*,
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
) -> None:
if not isinstance(input_tensor_grads, list):
input_tensor_grads = [input_tensor_grads]
for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes):
if tensor_shape is None:
continue
p2p_communication.send_backward(input_tensor_grad, tensor_shape=tensor_shape, dtype=dtype, async_comm=async_comm)
p2p_communication.send_backward(
input_tensor_grad,
tensor_shape=tensor_shape,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
def send_forward_recv_backward(
......@@ -128,6 +176,7 @@ def send_forward_recv_backward(
*,
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
) -> List[Union[None, torch.Tensor, FutureTensor]]:
if not isinstance(output_tensors, list):
output_tensors = [output_tensors]
......@@ -136,7 +185,13 @@ def send_forward_recv_backward(
if tensor_shape is None:
output_tensor_grads.append(None)
continue
output_tensor_grad = p2p_communication.send_forward_recv_backward(output_tensor, tensor_shape=tensor_shape, dtype=dtype, async_comm=async_comm)
output_tensor_grad = p2p_communication.send_forward_recv_backward(
output_tensor,
tensor_shape=tensor_shape,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
output_tensor_grads.append(output_tensor_grad)
return output_tensor_grads
......@@ -147,6 +202,7 @@ def send_backward_recv_forward(
*,
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
) -> List[Union[None, torch.Tensor, FutureTensor]]:
if not isinstance(input_tensor_grads, list):
input_tensor_grads = [input_tensor_grads]
......@@ -155,7 +211,13 @@ def send_backward_recv_forward(
if tensor_shape is None:
input_tensors.append(None)
continue
input_tensor = p2p_communication.send_backward_recv_forward(input_tensor_grad, tensor_shape=tensor_shape, dtype=dtype, async_comm=async_comm)
input_tensor = p2p_communication.send_backward_recv_forward(
input_tensor_grad,
tensor_shape=tensor_shape,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
input_tensors.append(input_tensor)
return input_tensors
......@@ -173,6 +235,7 @@ def forward_backward_pipelining_without_interleaving(
disable_autocast: bool = False,
deallocate_pipeline_outputs: bool = False,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
**kwargs,
) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
"""Run non-interleaved 1F1B schedule, with communication between pipeline stages.
......@@ -192,13 +255,17 @@ def forward_backward_pipelining_without_interleaving(
Keyword args:
forward_only:
tensor_shape: Shape of tensor. Required for P2P communication.
tensor_shape: Shape of tensor. The tensor is expected to be 3D and its dimension order
is supposed to be ``(sequence, batch, hidden)``.
dtype: dtype used in p2p communication. If ``None`` (default value),
torch.float32 will be used even if ``autocast`` is enabled.
grad_scaler:
disable_autocast:
deallocate_pipeline_outputs: If :obj:`True`, free the data of the output tensor of
each pipeline stage. Experimental.
sequence_parallel_enabled: Set to :obj:`True` for this function to handle sequence length.
When :obj:`True`, the sequence length on each tensor model parallel rank is updated
to :math:`original\_sequence\_length / tensor\_model\_parallel\_world\_size`.
Returns:
a list of loss `torch.Tensor`s if the last stage, empty list otherwise.
......@@ -228,10 +295,18 @@ def forward_backward_pipelining_without_interleaving(
model_type = get_model_type(model)
rank: int = parallel_state.get_pipeline_model_parallel_rank()
recv_tensor_shapes: List[List[int]] = get_tensor_shapes(
rank - 1, model_type, tensor_shape=tensor_shape, decoder_sequence_length=decoder_sequence_length
rank - 1,
model_type,
tensor_shape=tensor_shape,
decoder_sequence_length=decoder_sequence_length,
sequence_parallel_enabled=sequence_parallel_enabled,
)
send_tensor_shapes: List[List[int]] = get_tensor_shapes(
rank, model_type, tensor_shape=tensor_shape, decoder_sequence_length=decoder_sequence_length
rank,
model_type,
tensor_shape=tensor_shape,
decoder_sequence_length=decoder_sequence_length,
sequence_parallel_enabled=sequence_parallel_enabled,
)
_logger.info(
......@@ -251,7 +326,12 @@ def forward_backward_pipelining_without_interleaving(
for i in range(num_warmup_microbatches):
_logger.debug(f"warmup iter: {i} / {num_warmup_microbatches}")
_logger.debug("receive fwd")
input_tensor = recv_forward(tensor_shapes=recv_tensor_shapes, dtype=dtype, async_comm=async_comm)
input_tensor = recv_forward(
tensor_shapes=recv_tensor_shapes,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
cur_microbatch: Optional[torch.Tensor] = get_kth_microbatch(batch, i)
output_tensor = forward_step(
forward_step_func,
......@@ -263,7 +343,13 @@ def forward_backward_pipelining_without_interleaving(
disable_autocast,
)
_logger.debug("send fwd")
send_forward(output_tensor, tensor_shapes=send_tensor_shapes, dtype=dtype, async_comm=async_comm)
send_forward(
output_tensor,
tensor_shapes=send_tensor_shapes,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
if not forward_only:
input_tensors.append(input_tensor)
......@@ -297,15 +383,32 @@ def forward_backward_pipelining_without_interleaving(
)
if forward_only:
_logger.debug("send fwd")
send_forward(output_tensor, tensor_shapes=send_tensor_shapes, dtype=dtype, async_comm=async_comm)
send_forward(
output_tensor,
tensor_shapes=send_tensor_shapes,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
if not last_iteration:
_logger.debug("receive fwd (last iteration)")
input_tensor = recv_forward(tensor_shapes=recv_tensor_shapes, dtype=dtype, async_comm=async_comm)
input_tensor = recv_forward(
tensor_shapes=recv_tensor_shapes,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
else:
_logger.debug("send fwd & receive bwd")
output_tensor_grad = send_forward_recv_backward(output_tensor, tensor_shapes=send_tensor_shapes, dtype=dtype, async_comm=async_comm)
output_tensor_grad = send_forward_recv_backward(
output_tensor,
tensor_shapes=send_tensor_shapes,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
# Add input_tensor and output_tensor to end of list.
input_tensors.append(input_tensor)
......@@ -328,10 +431,22 @@ def forward_backward_pipelining_without_interleaving(
if last_iteration:
input_tensor = None
_logger.debug("send bwd")
send_backward(input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype, async_comm=async_comm)
send_backward(
input_tensor_grad,
tensor_shapes=recv_tensor_shapes,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
else:
_logger.debug("send bwd and receive fwd")
input_tensor = send_backward_recv_forward(input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype, async_comm=async_comm)
input_tensor = send_backward_recv_forward(
input_tensor_grad,
tensor_shapes=recv_tensor_shapes,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
###################################################################################################################
# Run cooldown backward passes.
###################################################################################################################
......@@ -343,7 +458,12 @@ def forward_backward_pipelining_without_interleaving(
output_tensor = output_tensors.pop(0)
_logger.debug("receive bwd")
output_tensor_grad = recv_backward(tensor_shapes=send_tensor_shapes, dtype=dtype, async_comm=async_comm)
output_tensor_grad = recv_backward(
tensor_shapes=send_tensor_shapes,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
input_tensor_grad = backward_step(
input_tensor,
......@@ -355,6 +475,12 @@ def forward_backward_pipelining_without_interleaving(
)
_logger.debug("send bwd")
send_backward(input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype, async_comm=async_comm)
send_backward(
input_tensor_grad,
tensor_shapes=recv_tensor_shapes,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
return losses_reduced
......@@ -32,6 +32,7 @@ from apex.transformer.tensor_parallel.mappings import (
gather_from_tensor_model_parallel_region,
reduce_from_tensor_model_parallel_region,
scatter_to_tensor_model_parallel_region,
scatter_to_sequence_parallel_region,
)
from .random import (
......@@ -62,6 +63,7 @@ __all__ = [
"gather_from_tensor_model_parallel_region",
"reduce_from_tensor_model_parallel_region",
"scatter_to_tensor_model_parallel_region",
"scatter_to_sequence_parallel_region",
# random.py
"checkpoint",
"get_cuda_rng_tracker",
......
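
For illustration only (plain torch, no process groups; sizes and the rank value are hypothetical): scattering to the sequence parallel region means each tensor model parallel rank keeps its own slice of the first (sequence) dimension, mirroring the `_split_along_first_dim` helper further down in this diff.

import torch

tp_world_size, rank = 2, 0
x = torch.randn(16, 4, 8)                     # (sequence, batch, hidden)
local = x.chunk(tp_world_size, dim=0)[rank]   # this rank's sequence slice
assert local.shape == (8, 4, 8)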
......@@ -16,6 +16,9 @@
# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch
from typing import Optional, Dict, Tuple, List
import warnings
import torch
import torch.nn.functional as F
import torch.nn.init as init
......@@ -38,6 +41,9 @@ from apex.transformer.tensor_parallel.mappings import (
from apex.transformer.tensor_parallel.mappings import (
scatter_to_tensor_model_parallel_region,
)
from apex.transformer.tensor_parallel.mappings import (
reduce_scatter_to_sequence_parallel_region,
)
from apex.transformer.tensor_parallel.random import get_cuda_rng_tracker
from apex.transformer.tensor_parallel.utils import VocabUtility
from apex.transformer.log_util import get_transformer_logger
......@@ -60,13 +66,13 @@ _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
}
def param_is_not_tensor_parallel_duplicate(param):
def param_is_not_tensor_parallel_duplicate(param: torch.Tensor) -> bool:
return (
hasattr(param, "tensor_model_parallel") and param.tensor_model_parallel
) or (get_tensor_model_parallel_rank() == 0)
def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
def set_tensor_model_parallel_attributes(tensor: torch.Tensor, is_parallel: bool, dim: int, stride: int) -> None:
# Make sure the attributes are not set.
for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
assert not hasattr(tensor, attribute)
......@@ -76,7 +82,7 @@ def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
setattr(tensor, "partition_stride", stride)
def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor):
def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor: torch.Tensor) -> None:
def maybe_set(attribute, value):
if not hasattr(tensor, attribute):
setattr(tensor, attribute, value)
......@@ -85,7 +91,7 @@ def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor):
maybe_set(attribute, _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS[attribute])
def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
def copy_tensor_model_parallel_attributes(destination_tensor: torch.Tensor, source_tensor: torch.Tensor) -> None:
def maybe_copy(attribute):
if hasattr(source_tensor, attribute):
setattr(destination_tensor, attribute, getattr(source_tensor, attribute))
......@@ -95,7 +101,14 @@ def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
def _initialize_affine_weight_gpu(weight, init_method, partition_dim, stride=1):
"""Initialize affine weight for model parallel on GPU."""
"""Initialize affine weight for model parallel on GPU.
Args:
weight (Parameter):
init_method (Callable[[Tensor], None]): Taking a Tensor and initialize its elements.
partition_dim (int): Dimension to apply partition.
stride (int):
"""
set_tensor_model_parallel_attributes(
tensor=weight, is_parallel=True, dim=partition_dim, stride=stride
......@@ -164,14 +177,14 @@ class VocabParallelEmbedding(torch.nn.Module):
def __init__(
self,
num_embeddings,
embedding_dim,
num_embeddings: int,
embedding_dim: int,
init_method=init.xavier_normal_,
*,
params_dtype=torch.float32,
use_cpu_initialization=False,
params_dtype: torch.dtype=torch.float32,
use_cpu_initialization: bool = False,
):
super(VocabParallelEmbedding, self).__init__()
super().__init__()
# Keep the input dimensions.
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
......@@ -183,7 +196,7 @@ class VocabParallelEmbedding(torch.nn.Module):
self.sparse = False
self._weight = None
self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
# Divide the weight matrix along the vocaburaly dimension.
# Divide the weight matrix along the vocabulary dimension.
(
self.vocab_start_index,
self.vocab_end_index,
......@@ -256,18 +269,48 @@ class VocabParallelEmbedding(torch.nn.Module):
return output
class LinearWithGradAccumulationAndAsyncAllreduce(torch.autograd.Function):
"""Linear layer execution with asynchronous all-reduce and gradient accumulation fusion in backprop."""
class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
"""Linear layer execution with asynchronous communication and gradient accumulation fusion in backprop."""
@staticmethod
def forward(
ctx, input, weight, bias, gradient_accumulation_fusion, async_grad_allreduce
ctx,
input: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor],
gradient_accumulation_fusion: bool,
async_grad_allreduce: bool,
sequence_parallel_enabled: bool,
use_16bit_in_wgrad_accum_fusion: bool = False,
):
ctx.save_for_backward(input, weight)
ctx.use_bias = bias is not None
ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
ctx.async_grad_allreduce = async_grad_allreduce
output = torch.matmul(input, weight.t())
ctx.sequence_parallel_enabled = sequence_parallel_enabled
ctx.use_16bit_in_wgrad_accum_fusion = use_16bit_in_wgrad_accum_fusion
if ctx.sequence_parallel_enabled:
world_size = get_tensor_model_parallel_world_size()
# `input` is expected to be 3D with dimension order [sequence, batch, hidden]
shape = list(input.shape)
shape[0] *= world_size
all_gather_buffer = torch.empty(
shape,
dtype=input.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
torch.distributed.all_gather(
list(all_gather_buffer.chunk(world_size)),
input,
group=get_tensor_model_parallel_group(),
)
total_input = all_gather_buffer
else:
total_input = input
output = torch.matmul(total_input, weight.t())
if bias is not None:
output = output + bias
return output
......@@ -276,12 +319,40 @@ class LinearWithGradAccumulationAndAsyncAllreduce(torch.autograd.Function):
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
use_bias = ctx.use_bias
if ctx.sequence_parallel_enabled:
world_size = get_tensor_model_parallel_world_size()
shape = list(input.shape)
shape[0] *= world_size
all_gather_buffer = torch.empty(
shape,
dtype=input.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
handle = torch.distributed.all_gather(
list(all_gather_buffer.chunk(get_tensor_model_parallel_world_size())),
input,
group=get_tensor_model_parallel_group(),
async_op=True,
)
# Delay the start of input gradient computation shortly (3us) to have gather scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
total_input = all_gather_buffer
else:
total_input = input
grad_input = grad_output.matmul(weight)
if ctx.sequence_parallel_enabled:
handle.wait()
# Convert the tensor shapes to 2D for execution compatibility
grad_output = grad_output.view(
grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]
)
input = input.view(input.shape[0] * input.shape[1], input.shape[2])
total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2])
if ctx.async_grad_allreduce:
# Asynchronous all-reduce
handle = torch.distributed.all_reduce(
......@@ -291,87 +362,80 @@ class LinearWithGradAccumulationAndAsyncAllreduce(torch.autograd.Function):
# all-reduce scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
if ctx.gradient_accumulation_fusion:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(
input, grad_output, weight.main_grad
if ctx.sequence_parallel_enabled:
assert not ctx.async_grad_allreduce
sub_grad_input = torch.empty(input.shape, dtype=input.dtype, device=torch.cuda.current_device(), requires_grad=False)
handle = torch.distributed.reduce_scatter(
sub_grad_input,
list(grad_input.chunk(get_tensor_model_parallel_world_size())),
group=get_tensor_model_parallel_group(),
async_op=True,
)
# Delay the start of weight gradient computation shortly (3us) to have reduce scatter scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
if ctx.gradient_accumulation_fusion:
if not ctx.use_16bit_in_wgrad_accum_fusion:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(
total_input, grad_output, weight.main_grad
)
else:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(
total_input, grad_output, weight.main_grad
)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(input)
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.sequence_parallel_enabled:
handle.wait()
return sub_grad_input, grad_weight, grad_bias, None, None, None, None
if ctx.async_grad_allreduce:
handle.wait()
return grad_input, grad_weight, grad_bias, None, None
return grad_input, grad_weight, grad_bias, None, None, None, None
def linear_with_grad_accumulation_and_async_allreduce(
input, weight, bias, gradient_accumulation_fusion, async_grad_allreduce,
):
input: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor],
gradient_accumulation_fusion: bool,
async_grad_allreduce: bool,
sequence_parallel_enabled: bool,
) -> torch.Tensor:
args = _cast_if_autocast_enabled(
input, weight, bias, gradient_accumulation_fusion, async_grad_allreduce
input,
weight,
bias,
gradient_accumulation_fusion,
async_grad_allreduce,
sequence_parallel_enabled,
False, # use_16bit_in_wgrad_accum_fusion
)
with torch.cuda.amp.autocast(enabled=False):
return LinearWithGradAccumulationAndAsyncAllreduce.apply(*args)
class LinearWithGradAccumulationAndAsyncAllreduceIn16Bit(torch.autograd.Function):
"""Linear layer execution with asynchronous all-reduce and gradient accumulation fusion in backprop."""
@staticmethod
def forward(
ctx, input, weight, bias, gradient_accumulation_fusion, async_grad_allreduce
):
ctx.save_for_backward(input, weight)
ctx.use_bias = bias is not None
ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
ctx.async_grad_allreduce = async_grad_allreduce
output = torch.matmul(input, weight.t())
if bias is not None:
output = output + bias
return output
@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
use_bias = ctx.use_bias
grad_input = grad_output.matmul(weight)
# Convert the tensor shapes to 2D for execution compatibility
grad_output = grad_output.view(
grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]
)
input = input.view(input.shape[0] * input.shape[1], input.shape[2])
if ctx.async_grad_allreduce:
# Asynchronous all-reduce
handle = torch.distributed.all_reduce(
grad_input, group=get_tensor_model_parallel_group(), async_op=True
)
# Delay the start of weight gradient computation shortly (3us) to have
# all-reduce scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
if ctx.gradient_accumulation_fusion:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(
input, grad_output, weight.main_grad
)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_allreduce:
handle.wait()
return grad_input, grad_weight, grad_bias, None, None
return LinearWithGradAccumulationAndAsyncCommunication.apply(*args)
def linear_with_grad_accumulation_and_async_allreduce_in16bit(
input, weight, bias, gradient_accumulation_fusion, async_grad_allreduce,
):
input: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor],
gradient_accumulation_fusion: bool,
async_grad_allreduce: bool,
sequence_parallel_enabled: bool,
) -> torch.Tensor:
args = _cast_if_autocast_enabled(
input, weight, bias, gradient_accumulation_fusion, async_grad_allreduce
input,
weight,
bias,
gradient_accumulation_fusion,
async_grad_allreduce,
sequence_parallel_enabled,
True, # use_16bit_in_wgrad_accum_fusion
)
with torch.cuda.amp.autocast(enabled=False):
return LinearWithGradAccumulationAndAsyncAllreduceIn16Bit.apply(*args)
return LinearWithGradAccumulationAndAsyncCommunication.apply(*args)
class ColumnParallelLinear(torch.nn.Module):
......@@ -380,6 +444,10 @@ class ColumnParallelLinear(torch.nn.Module):
The linear layer is defined as Y = XA + b. A is parallelized along
its second dimension as A = [A_1, ..., A_p].
.. note::
Input is supposed to be three dimensional and each dimension
is expected to be sequence, batch, and hidden feature, respectively.
Arguments:
input_size: first dimension of matrix A.
output_size: second dimension of matrix A.
......@@ -403,6 +471,7 @@ class ColumnParallelLinear(torch.nn.Module):
use_cpu_initialization:
gradient_accumulation_fusion:
accumulation_in_fp16:
sequence_parallel_enabled:
"""
def __init__(
......@@ -421,8 +490,9 @@ class ColumnParallelLinear(torch.nn.Module):
use_cpu_initialization=False,
gradient_accumulation_fusion=False,
accumulation_in_fp16: bool = False,
sequence_parallel_enabled: bool = False,
):
super(ColumnParallelLinear, self).__init__()
super().__init__()
# Keep input parameters
self.input_size = input_size
......@@ -439,9 +509,7 @@ class ColumnParallelLinear(torch.nn.Module):
# Initialize weight.
if use_cpu_initialization:
self.weight = Parameter(
torch.empty(
self.output_size_per_partition, self.input_size, dtype=params_dtype
)
torch.empty(self.output_size_per_partition, self.input_size, dtype=params_dtype)
)
self.master_weight = _initialize_affine_weight_cpu(
self.weight,
......@@ -463,15 +531,11 @@ class ColumnParallelLinear(torch.nn.Module):
dtype=params_dtype,
)
)
_initialize_affine_weight_gpu(
self.weight, init_method, partition_dim=0, stride=stride
)
_initialize_affine_weight_gpu(self.weight, init_method, partition_dim=0, stride=stride)
if bias:
if use_cpu_initialization:
self.bias = Parameter(
torch.empty(self.output_size_per_partition, dtype=params_dtype)
)
self.bias = Parameter(torch.empty(self.output_size_per_partition, dtype=params_dtype))
else:
self.bias = Parameter(
torch.empty(
......@@ -490,14 +554,19 @@ class ColumnParallelLinear(torch.nn.Module):
self.async_tensor_model_parallel_allreduce = (
not no_async_tensor_model_parallel_allreduce and world_size > 1
)
if sequence_parallel_enabled:
if world_size <= 1:
warnings.warn(
f"`sequence_parallel_enabled` is set to `True`, but got world_size of {world_size}"
)
# sequence_parallel_enabled = False
self.sequence_parallel_enabled = sequence_parallel_enabled
if gradient_accumulation_fusion:
if not _grad_accum_fusion_available:
# Basically, apex.transformer module users are expected to install APEX's
# `--cpp_ext` and `--cuda_ext`. The example installation command is as follows:
# `pip install --global-option="--cpp_ext" --global-option="--cuda_ext" .`
# at the root of the APEX repository.
import warnings
warnings.warn(
"`gradient_accumulation_fusion` is set to `True` but "
"the custom CUDA extension of `fused_weight_gradient_mlp_cuda` module not "
......@@ -507,30 +576,45 @@ class ColumnParallelLinear(torch.nn.Module):
gradient_accumulation_fusion = False
self.gradient_accumulation_fusion = gradient_accumulation_fusion
if self.async_tensor_model_parallel_allreduce and self.sequence_parallel_enabled:
raise RuntimeError("`async_tensor_model_parallel_allreduce` and `sequence_parallel_enabled` cannot be enabled at the same time.")
self._forward_impl = (
linear_with_grad_accumulation_and_async_allreduce_in16bit
if accumulation_in_fp16
else linear_with_grad_accumulation_and_async_allreduce
)
def forward(self, input_):
def forward(self, input_: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Forward of ColumnParallelLinear
Args:
input_: 3D tensor whose order of dimension is [sequence, batch, hidden]
Returns:
- output
- bias
"""
bias = self.bias if not self.skip_bias_add else None
if not self.async_tensor_model_parallel_allreduce:
# Set up backprop all-reduce.
input_parallel = copy_to_tensor_model_parallel_region(input_)
else:
if self.async_tensor_model_parallel_allreduce or self.sequence_parallel_enabled:
input_parallel = input_
else:
input_parallel = copy_to_tensor_model_parallel_region(input_)
# Matrix multiply.
output_parallel = self._forward_impl(
input_parallel,
self.weight,
bias,
self.gradient_accumulation_fusion,
self.async_tensor_model_parallel_allreduce,
input=input_parallel,
weight=self.weight,
bias=bias,
gradient_accumulation_fusion=self.gradient_accumulation_fusion,
async_grad_allreduce=self.async_tensor_model_parallel_allreduce,
sequence_parallel_enabled=self.sequence_parallel_enabled,
)
if self.gather_output:
# All-gather across the partitions.
assert not self.sequence_parallel_enabled
output = gather_from_tensor_model_parallel_region(output_parallel)
else:
output = output_parallel
......@@ -550,6 +634,11 @@ class RowParallelLinear(torch.nn.Module):
| . |
| A_p |
- -
.. note::
Input is supposed to be three dimensional and each dimension
is expected to be sequence, batch, and hidden feature, respectively.
Arguments:
input_size: first dimension of matrix A.
output_size: second dimension of matrix A.
......@@ -566,6 +655,12 @@ class RowParallelLinear(torch.nn.Module):
skip_bias_add: This was added to enable performance optimization where bias
can be fused with other elementwise operations. We skip
adding bias but instead return it.
Keyword Arguments:
params_dtype:
use_cpu_initialization:
gradient_accumulation_fusion:
accumulation_in_fp16:
sequence_parallel_enabled:
"""
def __init__(
......@@ -581,8 +676,11 @@ class RowParallelLinear(torch.nn.Module):
*,
params_dtype=torch.float32,
use_cpu_initialization=False,
gradient_accumulation_fusion=False,
accumulation_in_fp16: bool = False,
sequence_parallel_enabled: bool = False,
):
super(RowParallelLinear, self).__init__()
super().__init__()
# Keep input parameters
self.input_size = input_size
......@@ -592,6 +690,10 @@ class RowParallelLinear(torch.nn.Module):
world_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, world_size)
self.skip_bias_add = skip_bias_add
self.gradient_accumulation_fusion = gradient_accumulation_fusion
self.sequence_parallel_enabled = sequence_parallel_enabled
if self.sequence_parallel_enabled and not self.input_is_parallel:
raise RuntimeError("To enable `sequence_parallel_enabled`, `input_is_parallel` must be `True`")
# as an argument to this function?
# Parameters.
......@@ -641,19 +743,46 @@ class RowParallelLinear(torch.nn.Module):
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
setattr(self.bias, "sequence_parallel_enabled", sequence_parallel_enabled)
else:
self.register_parameter("bias", None)
def forward(self, input_):
self._forward_impl = (
linear_with_grad_accumulation_and_async_allreduce_in16bit
if accumulation_in_fp16
else linear_with_grad_accumulation_and_async_allreduce
)
def forward(self, input_: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Forward of RowParallelLinear
Args:
input_: 3D tensor whose dimension order is [sequence, batch, hidden]
Returns:
- output
- bias
"""
# Set up backprop all-reduce.
if self.input_is_parallel:
input_parallel = input_
else:
assert not self.sequence_parallel_enabled
input_parallel = scatter_to_tensor_model_parallel_region(input_)
# Matrix multiply.
output_parallel = F.linear(input_parallel, self.weight)
output_parallel = self._forward_impl(
input=input_parallel,
weight=self.weight,
bias=None,
gradient_accumulation_fusion=self.gradient_accumulation_fusion,
async_grad_allreduce=False,
sequence_parallel_enabled=False,
)
# All-reduce across all the partitions.
output_ = reduce_from_tensor_model_parallel_region(output_parallel)
if self.sequence_parallel_enabled:
output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
else:
output_ = reduce_from_tensor_model_parallel_region(output_parallel)
if not self.skip_bias_add:
output = output_ + self.bias if self.bias is not None else output_
output_bias = None
......
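# NOTE (editorial, illustrative sketch; not part of the patch): the complementary
# RowParallelLinear usage with `sequence_parallel_enabled=True`, again assuming an
# initialized tensor model parallel group of size `tp`. The input must already be
# partitioned along its last dimension (`input_is_parallel=True`), and the output
# comes back sequence-sharded because the all-reduce is replaced by a reduce-scatter.
def _row_parallel_sequence_parallel_example(seq_len: int, batch: int, hidden: int, tp: int):
    row = RowParallelLinear(
        4 * hidden,
        hidden,
        input_is_parallel=True,
        sequence_parallel_enabled=True,
    )
    x = torch.randn(seq_len, batch, (4 * hidden) // tp, device="cuda")  # [s, b, 4h/tp]
    y, bias = row(x)  # y: [s/tp, b, h] after the reduce-scatter along the sequence dim
    return y, bias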
# coding=utf-8
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021-22, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -20,7 +20,7 @@ from apex.transformer.parallel_state import get_tensor_model_parallel_rank
from apex.transformer.tensor_parallel.utils import split_tensor_along_last_dim
def _reduce(input_):
def _reduce(input_: torch.Tensor) -> torch.Tensor:
"""All-reduce the input tensor across model parallel group."""
# Bypass the function if we are using only 1 GPU.
......@@ -33,7 +33,7 @@ def _reduce(input_):
return input_
def _split(input_):
def _split_along_last_dim(input_: torch.Tensor) -> torch.Tensor:
"""Split the tensor along its last dimension and keep the
corresponding slice."""
......@@ -52,8 +52,24 @@ def _split(input_):
return output
def _gather(input_):
"""Gather tensors and concatinate along the last dimension."""
def _split_along_first_dim(input_: torch.Tensor) -> torch.Tensor:
"""Split the tensor along its first dimension and keep the corresponding slice."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU for tensor model parallel.
if world_size == 1:
return input_
# Split along first dimension.
dim_size = input_.size(0)
assert dim_size % world_size == 0
local_dim_size = dim_size // world_size
dim_offset = get_tensor_model_parallel_rank() * local_dim_size
output = input_[dim_offset:dim_offset + local_dim_size].contiguous()
return output
def _gather_along_last_dim(input_: torch.Tensor) -> torch.Tensor:
"""Gather tensors and concatenate along the last dimension."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
......@@ -76,9 +92,57 @@ def _gather(input_):
return output
def _gather_along_first_dim(input_: torch.Tensor) -> torch.Tensor:
"""Gather tensors and concatenate along the first dimension."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
shape = list(input_.shape)
shape[0] *= world_size
output = torch.empty(shape, dtype=input_.dtype, device=torch.cuda.current_device())
# The original implementation uses `_all_gather_base` as follows.
# The commented-out call is deliberately kept for reference so that we can
# switch to that API once it becomes public/stable.
# torch.distributed._all_gather_base(output, input_.contiguous(), group=get_tensor_model_parallel_group())
torch.distributed.all_gather(
list(output.chunk(world_size)),
input_.contiguous(),
group=get_tensor_model_parallel_group(),
)
return output
def _reduce_scatter_along_first_dim(input_: torch.Tensor) -> torch.Tensor:
"""Reduce-scatter the input tensor across model parallel group."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
shape = list(input_.shape)
assert shape[0] % world_size == 0
shape[0] //= world_size
output = torch.empty(shape, dtype=input_.dtype, device=torch.cuda.current_device())
# The original implementation uses `_reduce_scatter_base` as follows.
# The commented-out call is deliberately kept for reference so that we can
# switch to that API once it becomes public/stable.
# torch.distributed._reduce_scatter_base(output, input_.contiguous(), group=get_tensor_model_parallel_group())
torch.distributed.reduce_scatter(
output,
list(input_.contiguous().chunk(world_size)),
group=get_tensor_model_parallel_group(),
)
return output
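# NOTE (editorial, illustrative sketch; not part of the patch): the two first-dim
# collectives above are shape inverses of each other. Assuming an initialized
# tensor model parallel group of size `tp` and a local shard `x` of shape
# [s/tp, b, h], the round trip below returns `tp * x` on every rank (the
# all-gathered tensor is identical on all ranks, so the reduce sums `tp` copies).
def _first_dim_roundtrip_example(x: torch.Tensor) -> torch.Tensor:
    full = _gather_along_first_dim(x)             # [s, b, h]
    back = _reduce_scatter_along_first_dim(full)  # [s/tp, b, h]
    return back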
class _CopyToModelParallelRegion(torch.autograd.Function):
"""Pass the input to the model parallel region."""
"""Pass the input to the tensor model parallel region."""
# FIXME(mkozuki): The definitions of the static symbolic methods don't look correct according to
# https://pytorch.org/docs/stable/onnx.html#static-symbolic-method
@staticmethod
def symbolic(graph, input_):
return input_
......@@ -93,8 +157,10 @@ class _CopyToModelParallelRegion(torch.autograd.Function):
class _ReduceFromModelParallelRegion(torch.autograd.Function):
"""All-reduce the input from the model parallel region."""
"""All-reduce the input from the tensor model parallel region."""
# FIXME(mkozuki): The definitions of the static symbolic methods don't look correct according to
# https://pytorch.org/docs/stable/onnx.html#static-symbolic-method
@staticmethod
def symbolic(graph, input_):
return _reduce(input_)
......@@ -111,33 +177,95 @@ class _ReduceFromModelParallelRegion(torch.autograd.Function):
class _ScatterToModelParallelRegion(torch.autograd.Function):
"""Split the input and keep only the corresponding chuck to the rank."""
# FIXME(mkozuki): The definitions of the static symbolic methods don't look correct according to
# https://pytorch.org/docs/stable/onnx.html#static-symbolic-method
@staticmethod
def symbolic(graph, input_):
return _split(input_)
return _split_along_last_dim(input_)
@staticmethod
def forward(ctx, input_):
return _split(input_)
return _split_along_last_dim(input_)
@staticmethod
def backward(ctx, grad_output):
return _gather(grad_output)
return _gather_along_last_dim(grad_output)
class _GatherFromModelParallelRegion(torch.autograd.Function):
"""Gather the input from model parallel region and concatinate."""
"""Gather the input from tensor model parallel region and concatenate."""
# FIXME(mkozuki): The definitions of the static symbolic methods don't look correct according to
# https://pytorch.org/docs/stable/onnx.html#static-symbolic-method
@staticmethod
def symbolic(graph, input_):
return _gather_along_last_dim(input_)
@staticmethod
def forward(ctx, input_):
return _gather_along_last_dim(input_)
@staticmethod
def backward(ctx, grad_output):
return _split_along_last_dim(grad_output)
class _ScatterToSequenceParallelRegion(torch.autograd.Function):
"""Split the input and keep only the corresponding chunk to the rank."""
# FIXME(mkozuki): The definitions of the static symbolic methods don't look correct according to
# https://pytorch.org/docs/stable/onnx.html#static-symbolic-method
@staticmethod
def symbolic(graph, input_):
return _gather(input_)
return _split_along_first_dim(input_)
@staticmethod
def forward(ctx, input_):
return _gather(input_)
return _split_along_first_dim(input_)
@staticmethod
def backward(ctx, grad_output):
return _split(grad_output)
return _gather_along_first_dim(grad_output)
class _GatherFromSequenceParallelRegion(torch.autograd.Function):
"""Gather the input from sequence parallel region and concatenate."""
# FIXME(mkozuki): The definitions of the static symbolic methods don't look correct according to
# https://pytorch.org/docs/stable/onnx.html#static-symbolic-method
@staticmethod
def symbolic(graph, input_, to_model_parallel: bool = True):
return _gather_along_first_dim(input_)
@staticmethod
def forward(ctx, input_, to_model_parallel: bool = True):
ctx.to_model_parallel = to_model_parallel
return _gather_along_first_dim(input_)
@staticmethod
def backward(ctx, grad_output):
if ctx.to_model_parallel:
return _reduce_scatter_along_first_dim(grad_output), None
else:
return _split_along_first_dim(grad_output), None
class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function):
"""Reduce scatter the input from the sequence parallel region and concatenate."""
# FIXME(mkozuki): The definitions of the static symbolic methods don't look correct according to
# https://pytorch.org/docs/stable/onnx.html#static-symbolic-method
@staticmethod
def symbolic(graph, input_):
return _reduce_scatter_along_first_dim(input_)
@staticmethod
def forward(ctx, input_):
return _reduce_scatter_along_first_dim(input_)
@staticmethod
def backward(ctx, grad_output):
return _gather_along_first_dim(grad_output)
# -----------------
......@@ -145,17 +273,40 @@ class _GatherFromModelParallelRegion(torch.autograd.Function):
# -----------------
def copy_to_tensor_model_parallel_region(input_):
def copy_to_tensor_model_parallel_region(input_: torch.Tensor) -> torch.Tensor:
return _CopyToModelParallelRegion.apply(input_)
def reduce_from_tensor_model_parallel_region(input_):
def reduce_from_tensor_model_parallel_region(input_: torch.Tensor) -> torch.Tensor:
return _ReduceFromModelParallelRegion.apply(input_)
def scatter_to_tensor_model_parallel_region(input_):
def scatter_to_tensor_model_parallel_region(input_: torch.Tensor) -> torch.Tensor:
return _ScatterToModelParallelRegion.apply(input_)
def gather_from_tensor_model_parallel_region(input_):
def gather_from_tensor_model_parallel_region(input_: torch.Tensor) -> torch.Tensor:
return _GatherFromModelParallelRegion.apply(input_)
def scatter_to_sequence_parallel_region(input_: torch.Tensor) -> torch.Tensor:
return _ScatterToSequenceParallelRegion.apply(input_)
def gather_from_sequence_parallel_region(input_: torch.Tensor, to_model_parallel: bool = True) -> torch.Tensor:
return _GatherFromSequenceParallelRegion.apply(input_, to_model_parallel)
def reduce_scatter_to_sequence_parallel_region(input_: torch.Tensor) -> torch.Tensor:
return _ReduceScatterToSequenceParallelRegion.apply(input_)
__all__ = [
"copy_to_tensor_model_parallel_region",
"reduce_from_tensor_model_parallel_region",
"scatter_to_tensor_model_parallel_region",
"gather_from_tensor_model_parallel_region",
"scatter_to_sequence_parallel_region",
"gather_from_sequence_parallel_region",
"reduce_scatter_to_sequence_parallel_region",
]
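# NOTE (editorial, illustrative sketch; not part of the patch): shape contract of
# the newly exported sequence-parallel helpers, assuming an initialized tensor
# model parallel group of size `tp` and a full activation `x_full` of shape [s, b, h].
def _sequence_parallel_helpers_example(x_full: torch.Tensor) -> torch.Tensor:
    x_shard = scatter_to_sequence_parallel_region(x_full)   # [s/tp, b, h]
    # With `to_model_parallel=True` (the default) the backward is a reduce-scatter,
    # which is what the tensor-parallel linear layers expect; with `False` the
    # backward is a plain split instead.
    x_back = gather_from_sequence_parallel_region(x_shard)  # [s, b, h]
    return x_back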
......@@ -13,6 +13,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO(mkozuki): Remove this file as Megatron-LM seems to have done so.
import torch
......
# coding=utf-8
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021-22, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -13,9 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO (mkozuki): Audit this file.
# Some functions here don't really relate to `random` in tensor_parallel;
# rather, they mainly exist for gradient checkpointing (torch.utils.checkpoint).
# NOTE(mkozuki): This file is based on megatron-lm/mpu/random.py with some differences:
# - Not using "viewless" tensor:
# - _kernel_make_viewless_tensor
# - MakeViewlessTensor
# - make_viewless_tensor
# - assert_viewless_tensor
# - safely_set_viewless_tensor_data
# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch
......@@ -35,13 +39,12 @@ from apex.transformer.utils import gather_split_1d_tensor
# Default name for the model parallel rng tracker.
_MODEL_PARALLEL_RNG_TRACKER_NAME = "model-parallel-rng"
# TODO(mkozuki): Remove `_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER` as megatron-lm doesn't seem to use it.
# Whether to apply model parallelism to checkpointed hidden states.
_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = None
# TODO (mkozuki): Consider the possibility of removing `tensor_model_parallel_size`;
# `get_tensor_model_parallel_world_size()` might be an alternative.
# TODO(mkozuki): Remove `init_checkpointed_activations_memory_buffer` as megatron-lm doesn't seem to use it.
def init_checkpointed_activations_memory_buffer(
micro_batch_size,
max_position_embeddings,
......@@ -77,6 +80,7 @@ def init_checkpointed_activations_memory_buffer(
)
# TODO(mkozuki): Remove `reset_checkpointed_activations_memory_buffer` as megatron-lm doesn't seem to use it.
def reset_checkpointed_activations_memory_buffer():
"""Reset the memory used for checkpointing."""
if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
......@@ -86,7 +90,7 @@ def reset_checkpointed_activations_memory_buffer():
def _set_cuda_rng_state(new_state, device=-1):
"""Sets the random number generator state of the current GPU.
Argumentss:
Arguments:
new_state (torch.ByteTensor): The desired state
This function is adapted from PyTorch repo (torch.cuda.set_rng_state)
with a single change: the input state is not cloned. Cloning caused
......@@ -239,8 +243,9 @@ class CheckpointFunction(torch.autograd.Function):
"""
@staticmethod
def forward(ctx, run_function, *args):
def forward(ctx, run_function, distribute_saved_activations, *args):
ctx.run_function = run_function
ctx.distribute_saved_activations = distribute_saved_activations
# Copy the rng states.
ctx.fwd_cpu_rng_state = torch.get_rng_state()
......@@ -252,10 +257,8 @@ class CheckpointFunction(torch.autograd.Function):
# Divide hidden states across model parallel group and only keep
# the chunk corresponding to the current rank.
if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
ctx.input_0_shape = args[0].data.shape
args[0].data = split_tensor_into_1d_equal_chunks(args[0].data)
args[0].data = _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.add(args[0].data)
if ctx.distribute_saved_activations:
ctx.input_0_shape = args[0].shape
# Store everything.
ctx.save_for_backward(*args)
......@@ -269,9 +272,6 @@ class CheckpointFunction(torch.autograd.Function):
"please use .backward() if possible"
)
inputs = ctx.saved_tensors
if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
inputs[0].data = gather_split_1d_tensor(inputs[0].data)
inputs[0].data = inputs[0].data.view(ctx.input_0_shape)
# Store the current states.
bwd_cpu_rng_state = torch.get_rng_state()
......@@ -300,10 +300,12 @@ class CheckpointFunction(torch.autograd.Function):
inp.grad if isinstance(inp, torch.Tensor) else inp
for inp in detached_inputs
)
return (None,) + grads
return (None, None) + grads
def checkpoint(function, *args):
# NOTE(mkozuki): It doesn't look like `distribute_saved_activations` is used in apex.transformer
# but I added this change to reduce the superficial difference from Megatron-LM.
def checkpoint(function, distribute_saved_activations, *args):
"""Checkpoint a model or part of the model.
This has been directly copied from torch.utils.checkpoint."""
return CheckpointFunction.apply(function, *args)
return CheckpointFunction.apply(function, distribute_saved_activations, *args)
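# NOTE (editorial, illustrative sketch; not part of the patch): the new signature
# threads `distribute_saved_activations` through as the second positional argument,
# mirroring Megatron-LM. A minimal caller (requires a CUDA device because the CUDA
# RNG state is captured in forward):
def _checkpoint_example(x: torch.Tensor) -> torch.Tensor:
    def run(inp):
        return torch.sin(inp) * 2.0
    # `False` is `distribute_saved_activations`; it is passed through but, per the
    # NOTE above, currently unused inside apex.transformer.
    return checkpoint(run, False, x)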
......@@ -12,12 +12,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Sequence
import torch
from apex.transformer.utils import divide
def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
def split_tensor_along_last_dim(
tensor: torch.Tensor,
num_partitions: int,
contiguous_split_chunks: bool = False,
) -> List[torch.Tensor]:
"""Split a tensor along its last dimension.
Arguments:
tensor: input tensor.
......@@ -44,14 +50,14 @@ class VocabUtility:
@staticmethod
def vocab_range_from_per_partition_vocab_size(
per_partition_vocab_size, rank, world_size
):
per_partition_vocab_size: int, rank, world_size: int
) -> Sequence[int]:
index_f = rank * per_partition_vocab_size
index_l = index_f + per_partition_vocab_size
return index_f, index_l
@staticmethod
def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int, world_size: int) -> Sequence[int]:
per_partition_vocab_size = divide(global_vocab_size, world_size)
return VocabUtility.vocab_range_from_per_partition_vocab_size(
per_partition_vocab_size, rank, world_size
......
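# NOTE (editorial, worked example; not part of the patch): with a padded vocab of
# 50304 split across world_size=8, each partition holds 50304 / 8 = 6288 entries,
# so rank 3 owns the half-open range [3 * 6288, 4 * 6288) = [18864, 25152).
def _vocab_range_example():
    return VocabUtility.vocab_range_from_global_vocab_size(50304, rank=3, world_size=8)  # (18864, 25152)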
......@@ -39,9 +39,13 @@ def parse_args(extra_args_provider=None, defaults={},
parser = _add_data_args(parser)
parser = _add_autoresume_args(parser)
parser = _add_biencoder_args(parser)
parser = _add_vit_args(parser)
parser = _add_vision_args(parser)
parser = _add_logging_args(parser)
# NOTE(mkozuki): This option is added to investigate the potential of `torch.autograd.graph.save_on_cpu()`.
# ref: https://pytorch.org/docs/stable/autograd.html#torch.autograd.graph.save_on_cpu.
parser.add_argument('--cpu-offload', action='store_true', default=False, help='Turns on CPU offloading')
# Custom arguments.
if extra_args_provider is not None:
parser = extra_args_provider(parser)
......@@ -65,6 +69,11 @@ def parse_args(extra_args_provider=None, defaults={},
args.pipeline_model_parallel_size = min(
args.pipeline_model_parallel_size,
(args.world_size // args.tensor_model_parallel_size))
args.transformer_pipeline_model_parallel_size = (
args.pipeline_model_parallel_size - 1
if args.standalone_embedding_stage else
args.pipeline_model_parallel_size
)
# Checks.
model_parallel_size = args.pipeline_model_parallel_size * \
args.tensor_model_parallel_size
......@@ -98,13 +107,18 @@ def parse_args(extra_args_provider=None, defaults={},
'longer valid, use --tensor-model-parallel-size instead'
del args.model_parallel_size
if args.checkpoint_activations:
args.activations_checkpoint_method = 'uniform'
args.recompute_granularity = 'full'
args.recompute_method = 'uniform'
if args.rank == 0:
print('--checkpoint-activations is no longer valid, '
'use --activation-checkpoint-method instead. '
'Defaulting to activation-checkpoint-method=uniform.')
'use --recompute-granularity and --recompute-method instead. '
'Defaulting to recompute-granularity=full and recompute-method=uniform.')
del args.checkpoint_activations
if args.recompute_activations:
args.recompute_granularity = 'selective'
del args.recompute_activations
# Set input defaults.
for key in defaults:
# For default to be valid, it should not be provided in the
......@@ -166,6 +180,14 @@ def parse_args(extra_args_provider=None, defaults={},
if args.accumulate_allreduce_grads_in_fp32:
assert args.DDP_impl == 'local'
assert args.use_contiguous_buffers_in_local_ddp
else:
if args.gradient_accumulation_fusion:
args.gradient_accumulation_fusion = False
if args.rank == 0:
print('Gradient accumulation fusion to linear layer weight '
'gradient computation is supported only with fp32 '
'gradient accumulation. Setting gradient_accumulation_fusion '
'to False', flush=True)
# For torch DDP, we do not use contiguous buffer
if args.DDP_impl == 'torch':
......@@ -244,17 +266,51 @@ def parse_args(extra_args_provider=None, defaults={},
if args.fp32_residual_connection:
assert args.fp16 or args.bf16, \
'residual connection in fp32 only supported when using fp16 or bf16.'
# Activation checkpointing.
if args.distribute_checkpointed_activations:
if args.weight_decay_incr_style == 'constant':
assert args.start_weight_decay is None
assert args.end_weight_decay is None
args.start_weight_decay = args.weight_decay
args.end_weight_decay = args.weight_decay
else:
assert args.start_weight_decay is not None
assert args.end_weight_decay is not None
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
# Persistent fused layer norm.
if TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 11):
args.no_persist_layer_norm = True
if args.rank == 0:
print('Persistent fused layer norm kernel is supported from '
'pytorch v1.11 (nvidia pytorch container paired with v1.11). '
'Defaulting to no_persist_layer_norm=True')
# Activation recomputing.
if args.distribute_saved_activations:
assert args.tensor_model_parallel_size > 1, 'can distribute ' \
'checkpointed activations only across tensor model ' \
'recomputed activations only across tensor model ' \
'parallel groups'
assert args.activations_checkpoint_method is not None, \
'for distribute-checkpointed-activations to work you '\
'need to use a activation-checkpoint method '
assert args.num_layers_per_virtual_pipeline_stage is None, \
'currently distrobuted checkpoint activations only supported for ' \
'nointerleaved pipeline parallelism'
assert args.recompute_granularity == 'full', \
'distributed recompute activations is only '\
'applicable to full recompute granularity'
assert args.recompute_method is not None, \
'for distributed recompute activations to work you '\
'need to use a recompute method '
assert (TORCH_MAJOR, TORCH_MINOR) >= (1, 10), \
'distributed recompute activations are supported for pytorch ' \
'v1.10 and above (Nvidia Pytorch container >= 21.07). Current ' \
'pytorch version is v%s.%s.' % (TORCH_MAJOR, TORCH_MINOR)
if args.recompute_granularity == 'selective':
assert args.recompute_method is None, \
'recompute method is not yet supported for ' \
'selective recomputing granularity'
# disable async_tensor_model_parallel_allreduce when
# model parallel memory optimization is enabled
if args.sequence_parallel:
args.async_tensor_model_parallel_allreduce = False
_print_args(args)
return args
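# NOTE (editorial, illustrative sketch; not part of the patch): the deprecated
# activation-checkpoint flags handled above now map onto the new recompute options.
# A hypothetical standalone helper that mirrors that mapping, for readability only:
def _map_deprecated_recompute_flags(checkpoint_activations: bool, recompute_activations: bool):
    granularity, method = None, None
    if checkpoint_activations:       # legacy --checkpoint-activations
        granularity, method = 'full', 'uniform'
    if recompute_activations:        # --recompute-activations shortcut
        granularity = 'selective'
    return granularity, method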
......@@ -279,6 +335,18 @@ def _check_arg_is_not_none(args, arg):
assert getattr(args, arg) is not None, '{} argument is None'.format(arg)
def _add_inference_args(parser):
group = parser.add_argument_group(title='inference')
group.add_argument('--inference-batch-times-seqlen-threshold',
type=int, default=512,
help='During inference, if batch-size times '
'sequence-length is smaller than this threshold '
'then we will not use pipelining, otherwise we will.')
return parser
def _add_network_size_args(parser):
group = parser.add_argument_group(title='network size')
......@@ -318,6 +386,8 @@ def _add_network_size_args(parser):
group.add_argument('--bert-no-binary-head', action='store_false',
help='Disable BERT binary head.',
dest='bert_binary_head')
group.add_argument('--num-experts', type=int, default=None,
help='Number of Experts in Switch Transformer (None means no Switch)')
return parser
......@@ -354,6 +424,9 @@ def _add_logging_args(parser):
group.add_argument('--log-memory-to-tensorboard',
action='store_true',
help='Enable memory logging to tensorboard.')
group.add_argument('--log-world-size-to-tensorboard',
action='store_true',
help='Enable world size logging to tensorboard.')
return parser
......@@ -367,6 +440,13 @@ def _add_regularization_args(parser):
help='Dropout probability for hidden state transformer.')
group.add_argument('--weight-decay', type=float, default=0.01,
help='Weight decay coefficient for L2 regularization.')
group.add_argument('--start-weight-decay', type=float,
help='Initial weight decay coefficient for L2 regularization.')
group.add_argument('--end-weight-decay', type=float,
help='End of run weight decay coefficient for L2 regularization.')
group.add_argument('--weight-decay-incr-style', type=str, default='constant',
choices=['constant', 'linear', 'cosine'],
help='Weight decay increment function.')
group.add_argument('--clip-grad', type=float, default=1.0,
help='Gradient clipping based on global L2 norm.')
group.add_argument('--adam-beta1', type=float, default=0.9,
......@@ -413,27 +493,40 @@ def _add_training_args(parser):
' (1024 - 16) / 8 = 126 intervals will increase'
'the batch size linearly to 1024. In each interval'
'we will use approximately 300000 / 126 = 2380 samples.')
group.add_argument('--checkpoint-activations', action='store_true',
help='Checkpoint activation to allow for training '
group.add_argument('--recompute-activations', action='store_true',
help='Recompute activations to allow for training '
'with larger models, sequences, and batch sizes.')
group.add_argument('--distribute-checkpointed-activations',
group.add_argument('--recompute-granularity', type=str, default=None,
choices=['full', 'selective'],
help='Checkpoint activations to allow for training '
'with larger models, sequences, and batch sizes. '
'It is supported at two granularities 1) full: '
'whole transformer layer is recomputed, '
'2) selective: core attention part of the transformer '
'layer is recomputed.')
group.add_argument('--distribute-saved-activations',
action='store_true',
help='If set, distribute checkpointed activations '
help='If set, distribute recomputed activations '
'across model parallel group.')
group.add_argument('--activations-checkpoint-method', type=str, default=None,
group.add_argument('--recompute-method', type=str, default=None,
choices=['uniform', 'block'],
help='1) uniform: uniformly divide the total number of '
'Transformer layers and checkpoint the input activation of '
'each divided chunk, '
'2) checkpoint the input activations of only a set number of '
'Transformer layers and recompute the input activation of '
'each divided chunk at specified granularity, '
'2) recompute the input activations of only a set number of '
'individual Transformer layers per pipeline stage and do the '
'rest without any checkpointing'
'default) do not apply activations checkpoint to any layers')
group.add_argument('--activations-checkpoint-num-layers', type=int, default=1,
'rest without any recomputing at specified granularity, '
'default) do not apply activations recompute to any layers')
group.add_argument('--recompute-num-layers', type=int, default=1,
help='1) uniform: the number of Transformer layers in each '
'uniformly divided checkpoint unit, '
'uniformly divided recompute unit, '
'2) block: the number of individual Transformer layers '
'to checkpoint within each pipeline stage.')
'to recompute within each pipeline stage.')
# deprecated
group.add_argument('--checkpoint-activations', action='store_true',
help='Checkpoint activation to allow for training '
'with larger models, sequences, and batch sizes.')
group.add_argument('--train-iters', type=int, default=None,
help='Total number of iterations to train over all '
'training runs. Note that either train-iters or '
......@@ -472,7 +565,20 @@ def _add_training_args(parser):
action='store_true',
help='Disable asynchronous execution of '
'tensor-model-parallel all-reduce with weight '
'gradient compuation of a column-linear layer.')
'gradient computation of a column-linear layer.',
dest='async_tensor_model_parallel_allreduce')
group.add_argument('--no-persist-layer-norm', action='store_true',
help='Disable using persistent fused layer norm kernel. '
'This kernel supports only a set of hidden sizes. Please '
'check persist_ln_hidden_sizes if your hidden '
'size is supported.')
group.add_argument('--sequence-parallel', action='store_true',
help='Enable sequence parallel optimization.')
group.add_argument('--no-gradient-accumulation-fusion',
action='store_false',
help='Disable fusing gradient accumulation to weight '
'gradient computation of linear layers',
dest='gradient_accumulation_fusion')
return parser
......@@ -640,13 +746,16 @@ def _add_distributed_args(parser):
group.add_argument('--use-cpu-initialization', action='store_true',
default=None, help='If set, affine parallel weights '
'initialization uses CPU' )
group.add_argument('--cpu-offload', action='store_true',
default=False, help='Turns on CPU offloading')
group.add_argument('--empty-unused-memory-level', default=0, type=int,
choices=[0, 1, 2],
help='Call torch.cuda.empty_cache() each iteration '
'(training and eval), to reduce fragmentation.'
'0=off, 1=moderate, 2=aggressive.')
group.add_argument('--standalone-embedding-stage', action='store_true',
default=False, help='If set, *input* embedding layer '
'is placed on its own pipeline stage, without any '
'transformer layers. (For T5, this flag currently only '
'affects the encoder embedding.)')
return parser
......@@ -793,16 +902,70 @@ def _add_biencoder_args(parser):
return parser
def _add_vit_args(parser):
group = parser.add_argument_group(title="vit")
def _add_vision_args(parser):
group = parser.add_argument_group(title="vision")
# general vision arguments
group.add_argument('--num-classes', type=int, default=1000,
help='Number of classes in the vision classification task')
group.add_argument('--img-dim', type=int, default=224,
help='Image size for vision classification task')
group.add_argument('--img-h', type=int, default=224,
help='Image height for vision classification task')
group.add_argument('--img-w', type=int, default=224,
help='Image width for vision classification task')
group.add_argument('--num-channels', type=int, default=3,
help='Number of channels in input image data')
group.add_argument('--patch-dim', type=int, default=16,
help='patch dimension used in vit')
help='patch dimension')
group.add_argument('--classes-fraction', type=float, default=1.0,
help='training with fraction of classes.')
group.add_argument('--data-per-class-fraction', type=float, default=1.0,
help='training with fraction of data per class.')
group.add_argument('--no-data-sharding', action='store_false',
help='Disable data sharding.',
dest='data_sharding')
group.add_argument('--head-lr-mult', type=float, default=1.0,
help='learning rate multiplier for head during finetuning')
# pretraining type and backbone selection
group.add_argument('--vision-pretraining', action='store_true',
help='flag to indicate vision pretraining')
group.add_argument('--vision-pretraining-type', type=str, default='classify',
choices=['classify', 'inpaint', 'dino'],
help='pretraining objectives')
group.add_argument('--vision-backbone-type', type=str, default='vit',
choices=['vit', 'mit', 'swin'],
help='backbone type')
group.add_argument('--swin-backbone-type', type=str, default='tiny',
choices=['tiny', 'base', 'h3'],
help='swin backbone type')
# inpainting arguments
group.add_argument('--mask-type', type=str, default='random',
choices=['random', 'row'],
help='mask types')
group.add_argument('--mask-factor', type=float, default=1.0,
help='mask size scaling parameter')
# dino arguments
group.add_argument('--iter-per-epoch', type=int, default=1250,
help='iterations per epoch')
group.add_argument('--dino-local-img-size', type=int, default=96,
help='Image size of DINO local crops')
group.add_argument('--dino-local-crops-number', type=int, default=10,
help='Number of local crops')
group.add_argument('--dino-head-hidden-size', type=int, default=2048,
help='Hidden dimension size in dino head')
group.add_argument('--dino-bottleneck-size', type=int, default=256,
help='Bottleneck dimension in dino head')
group.add_argument('--dino-freeze-last-layer', type=float, default=1,
help='Freezing last layer weights')
group.add_argument('--dino-norm-last-layer', action='store_true',
help='Disable Norm in last layer.')
group.add_argument('--dino-warmup-teacher-temp', type=float, default=0.04,
help='warmup teacher temperature')
group.add_argument('--dino-teacher-temp', type=float, default=0.07,
help='teacher temperature')
group.add_argument('--dino-warmup-teacher-temp-epochs', type=int, default=30,
help='warmup teacher temperature epochs')
return parser
......@@ -12,19 +12,28 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
import datetime
import os
import random
from typing import Optional, Union, List
from typing import Optional, Union, List, Tuple, Callable, Dict
import numpy
import torch
import torch.nn as nn
from apex import transformer
from apex.transformer.tensor_parallel import(
ColumnParallelLinear,
RowParallelLinear,
scatter_to_sequence_parallel_region,
)
from apex.transformer.pipeline_parallel.utils import (
average_losses_across_data_parallel_group,
)
from apex.transformer.pipeline_parallel.schedules.common import (
Batch,
)
from apex.transformer.testing import global_vars
......@@ -45,7 +54,10 @@ class MyLayer(nn.Module):
class MyModel(nn.Module):
def __init__(
self, hidden_size: int, pre_process: bool = False, post_process: bool = False
self,
hidden_size: int, pre_process: bool = False, post_process: bool = False,
*,
add_encoder: bool = False, add_decoder: bool = False,
) -> None:
super().__init__()
self.pre_process = pre_process
......@@ -68,8 +80,105 @@ class MyModel(nn.Module):
return self.layer(self.input_tensor)
def model_provider_func(hidden_size, pre_process, post_process) -> MyModel:
return MyModel(hidden_size, pre_process, post_process)
class ToyParallelMLP(nn.Module):
def __init__(
self,
hidden_size: int, pre_process: bool = False, post_process: bool = False,
*,
sequence_parallel_enabled: bool = False,
# TODO(mkozuki): Support these two?
add_encoder: bool = False, add_decoder: bool = False,
) -> None:
super().__init__()
self.pre_process = pre_process
self.post_process = post_process
self.sequence_parallel_enabled = sequence_parallel_enabled
ffn_hidden_size = 4 * hidden_size
self.dense_h_to_4h = ColumnParallelLinear(
hidden_size,
ffn_hidden_size,
gather_output=False,
# init_method=init_method,
skip_bias_add=True,
# use_cpu_initialization=use_cpu_initialization,
bias=True,
sequence_parallel_enabled=sequence_parallel_enabled,
no_async_tensor_model_parallel_allreduce=True,
)
self.dense_4h_to_h = RowParallelLinear(
ffn_hidden_size,
hidden_size,
input_is_parallel=True,
# init_method=output_layer_init_method,
skip_bias_add=False,
# use_cpu_initialization=use_cpu_initialization,
bias=True,
sequence_parallel_enabled=sequence_parallel_enabled,
)
self.activation_func = torch.nn.GELU()
def set_input_tensor(
self,
input_tensor: Union[torch.Tensor, List[torch.Tensor]],
) -> None:
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
self.input_tensor = input_tensor[0]
def forward(
self,
x: Optional[torch.Tensor],
) -> torch.Tensor:
"""Forward of Simplified ParallelMLP.
Args:
x: :obj:`None` if pipeline rank != pipeline first rank. When :obj:`None`,
`self.input_tensor` is taken care of by `forward_step` defined in
apex/transformer/pipeline_parallel/schedules/common.py
"""
# [s, b, h]
if self.input_tensor is None:
input = x
else:
input = self.input_tensor
intermediate_parallel, bias_parallel = self.dense_h_to_4h(input)
if bias_parallel is not None:
intermediate_parallel += bias_parallel
intermediate_parallel = self.activation_func(intermediate_parallel)
# [s, b, h]
output, output_bias = self.dense_4h_to_h(intermediate_parallel)
return output
def model_provider_func(
hidden_size: int,
pre_process: bool,
post_process: bool,
*,
add_encoder: bool = False,
add_decoder: bool = False) -> MyModel:
return MyModel(hidden_size, pre_process, post_process, add_encoder=add_encoder, add_decoder=add_decoder)
def mlp_provider_func(
hidden_size: int,
pre_process: bool,
post_process: bool,
*,
add_encoder: bool = False,
add_decoder: bool = False,
sequence_parallel_enabled: bool = False,
) -> ToyParallelMLP:
return ToyParallelMLP(
hidden_size,
pre_process,
post_process,
add_encoder=add_encoder,
add_decoder=add_decoder,
sequence_parallel_enabled=sequence_parallel_enabled,
)
def process_batch(batch):
......@@ -94,6 +203,33 @@ def fwd_step_func(batch, model):
return y, loss_func
@dataclass(frozen=True)
class ToyParallelMLPFwdBwdStepFunc:
sequence_parallel_enabled: bool
def __call__(
self,
batch: Batch,
model: torch.nn.Module,
) -> Tuple[torch.Tensor, Callable[[torch.Tensor], Tuple[torch.Tensor, Dict[str, torch.Tensor]]]]:
x = batch[0] if isinstance(batch, list) else batch
if isinstance(x, torch.Tensor):
x = x.transpose(0, 1).contiguous()
if self.sequence_parallel_enabled:
x = scatter_to_sequence_parallel_region(x)
y = model(x)
# note (mkozuki): This function is not nice, but it is enough for now
# to check the sanity of the ported pipeline functions.
def loss_func(x):
loss = torch.sum(x)
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {"avg": averaged_loss}
return y, loss_func
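# NOTE (editorial, illustrative sketch; not part of the patch): how the functor
# above is expected to be driven. It assumes the tensor/pipeline/data parallel
# state has already been initialized by the test harness and that `model` is a
# ToyParallelMLP built with the same `sequence_parallel_enabled` value. Batches
# arrive as [b, s, h]; the functor transposes to [s, b, h] and, when enabled,
# scatters along the sequence dimension.
def _toy_mlp_step_example(model: torch.nn.Module, batch_size: int, seq_len: int, hidden: int):
    model.set_input_tensor(None)  # first pipeline stage consumes the batch directly
    fwd_step = ToyParallelMLPFwdBwdStepFunc(sequence_parallel_enabled=True)
    batch = [torch.randn(batch_size, seq_len, hidden, device="cuda")]
    y, loss_func = fwd_step(batch, model)  # y is roughly [s/tp, b, h] with sequence parallel on
    loss, stats = loss_func(y)             # stats["avg"] is averaged across the data parallel group
    return loss, stats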
class IdentityLayer(torch.nn.Module):
def __init__(self, size, scale=1.0):
super(IdentityLayer, self).__init__()
......
import torch
import contextlib
from apex.normalization import FusedLayerNorm as LayerNorm
import torch
from apex.transformer import tensor_parallel
from apex.transformer.enums import AttnMaskType
from apex.transformer.enums import ModelType
from apex.transformer.layers import FusedLayerNorm as LayerNorm
from apex.transformer.testing.global_vars import get_args
from .standalone_gpt import get_language_model, get_linear_layer, init_method_normal, parallel_lm_logits, scaled_init_method_normal
from .standalone_gpt import MegatronModule
from apex.transformer.testing.standalone_transformer_lm import (
MegatronModule,
get_language_model,
get_linear_layer,
init_method_normal,
scaled_init_method_normal,
parallel_lm_logits,
)
def bert_extended_attention_mask(attention_mask):
# We create a 3D attention mask from a 2D tensor mask.
......@@ -23,6 +33,7 @@ def bert_extended_attention_mask(attention_mask):
return extended_attention_mask
def bert_position_ids(token_ids):
# Create position ids
seq_length = token_ids.size(1)
......@@ -32,6 +43,7 @@ def bert_position_ids(token_ids):
return position_ids
class BertLMHead(MegatronModule):
"""Masked LM head for Bert
......@@ -56,13 +68,18 @@ class BertLMHead(MegatronModule):
self.parallel_output = parallel_output
self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
setattr(self.dense.weight, 'sequence_parallel', args.sequence_parallel)
setattr(self.dense.bias, 'sequence_parallel', args.sequence_parallel)
self.layernorm = LayerNorm(
hidden_size, eps=layernorm_epsilon, sequence_parallel_enabled=args.sequence_parallel)
self.gelu = torch.nn.functional.gelu
if args.openai_gelu:
self.gelu = openai_gelu
elif args.onnx_safe:
self.gelu = erf_gelu
def forward(self, hidden_states, word_embeddings_weight):
hidden_states = self.dense(hidden_states)
hidden_states = self.gelu(hidden_states)
......@@ -73,6 +90,7 @@ class BertLMHead(MegatronModule):
bias=self.bias)
return output
def post_language_model_processing(lm_output, pooled_output,
lm_head, binary_head,
lm_labels,
......@@ -87,8 +105,12 @@ def post_language_model_processing(lm_output, pooled_output,
binary_logits = binary_head(pooled_output)
if lm_labels is None:
return lm_logits, binary_logits
# [s b h] => [b s h]
return lm_logits.transpose(0, 1).contiguous(), binary_logits
else:
# [b s] => [s b]
lm_labels = lm_labels.transpose(0, 1).contiguous()
# lm_logits: [s b h] lm_labels: [s b]
if fp16_lm_cross_entropy:
assert lm_logits.dtype == torch.half
lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits, lm_labels)
......@@ -116,7 +138,7 @@ class BertModel(MegatronModule):
self.parallel_output = parallel_output
self.pre_process = pre_process
self.post_process = post_process
self.cpu_offload = cpu_offload
init_method = init_method_normal(args.init_method_std)
scaled_init_method = scaled_init_method_normal(args.init_method_std,
args.num_layers)
......@@ -142,13 +164,17 @@ class BertModel(MegatronModule):
init_method)
self._binary_head_key = 'binary_head'
self.forward_context = contextlib.nullcontext
if cpu_offload:
self.forward_context = torch.autograd.graph.save_on_cpu
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
self.language_model.set_input_tensor(input_tensor)
def forward(self, bert_model_input, attention_mask,
tokentype_ids=None, lm_labels=None):
with torch.autograd.graph.save_on_cpu() if self.cpu_offload else contextlib.nullcontext():
with self.forward_context():
extended_attention_mask = bert_extended_attention_mask(attention_mask)
input_ids = bert_model_input
position_ids = bert_position_ids(input_ids)
......@@ -174,7 +200,7 @@ class BertModel(MegatronModule):
else:
return lm_output
# NOTE(mkozuki): This method is not maintained as apex only tests forward_backward with best effort.
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""For easy load when model is combined with other heads,
......@@ -197,6 +223,7 @@ class BertModel(MegatronModule):
= self.word_embeddings.state_dict(destination, prefix, keep_vars)
return state_dict_
# NOTE(mkozuki): This method is not maintained as apex only tests forward_backward with best effort.
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
......@@ -213,6 +240,16 @@ class BertModel(MegatronModule):
self.word_embeddings.load_state_dict(
state_dict[self._word_embeddings_for_head_key], strict=strict)
def bert_model_provider(pre_process=True, post_process=True, cpu_offload=False):
model = BertModel(num_tokentypes=0, add_binary_head=False, pre_process=pre_process, post_process=post_process, cpu_offload=cpu_offload)
args = get_args()
num_tokentypes = 2 if args.bert_binary_head else 0
model = BertModel(
num_tokentypes=num_tokentypes,
add_binary_head=args.bert_binary_head,
parallel_output=True,
pre_process=pre_process,
post_process=post_process,
cpu_offload=cpu_offload,
)
return model
# coding=utf-8
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021-22, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -13,1447 +12,67 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""GPT-2 model."""
import enum
import math
import contextlib
import json
import torch
import torch.nn.functional as F
import apex.transformer.utils
from apex.normalization import FusedLayerNorm as LayerNorm
from apex.transformer.functional import FusedScaleMaskSoftmax
from apex.transformer.enums import AttnMaskType
from apex.transformer.enums import ModelType
from apex.transformer import tensor_parallel
from apex.transformer import parallel_state
from apex.transformer.testing.global_vars import get_args
from apex.transformer.enums import LayerType
from apex.transformer.enums import AttnType
from apex.transformer.enums import AttnMaskType
_FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
_HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
_BF16_TYPES = (torch.BFloat16Tensor, torch.cuda.BFloat16Tensor)
class ModelType(enum.Enum):
encoder_or_decoder = 1
encoder_and_decoder = 2
###### BIAS GELU FUSION/ NO AUTOGRAD ################
# 1/sqrt(2*pi)-> 0.3989423
# 1/sqrt(2) -> 0.70710678
# sqrt(2/pi) -> 0.79788456
# this function is tanh approximation of gelu
# actual gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
@torch.jit.script
def bias_gelu(bias, y):
x = bias + y
return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def bias_gelu_back(g, bias, y):
x = bias + y
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
return ff * g
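# NOTE (editorial, illustrative check; not part of the patch): the fused jit kernels
# above implement the tanh approximation of GELU, so on PyTorch >= 1.12 they can be
# sanity-checked against the built-in approximate GELU:
def _check_bias_gelu(bias, y):
    expected = torch.nn.functional.gelu(bias + y, approximate="tanh")
    return torch.allclose(bias_gelu(bias, y), expected, atol=1e-5)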
class MegatronModule(torch.nn.Module):
"""Megatron specific extensions of torch Module with support
for pipelining."""
def __init__(self, share_word_embeddings=True):
super(MegatronModule, self).__init__()
self.share_word_embeddings = share_word_embeddings
def state_dict_for_save_checkpoint(self, destination=None, prefix="",
keep_vars=False):
"""Use this function to override the state dict for
saving checkpoints."""
return self.state_dict(destination, prefix, keep_vars)
def word_embeddings_weight(self):
if not parallel_state.is_pipeline_last_stage(ignore_virtual=True) or \
parallel_state.get_pipeline_model_parallel_world_size() == 1:
return self.language_model.embedding.word_embeddings.weight
else:
if not self.share_word_embeddings:
raise Exception("word_embeddings_weight() called for last "
"stage, but share_word_embeddings is false")
return self.word_embeddings.weight
def initialize_word_embeddings(self, init_method_normal):
args = get_args()
if not self.share_word_embeddings:
raise Exception("initialize_word_embeddings() was called but "
"share_word_embeddings is false")
# This function just initializes the word embeddings in the final stage
# when we are using pipeline parallelism. Nothing to do if we aren't
# using pipeline parallelism.
if args.pipeline_model_parallel_size == 1:
return
# Parameters are shared between the word embeddings layers, and the
# heads at the end of the model. In a pipelined setup with more than
# one stage, the initial embedding layer and the head are on different
# workers, so we do the following:
# 1. Create a second copy of word_embeddings on the last stage, with
# initial parameters of 0.0.
# 2. Do an all-reduce between the first and last stage to ensure that
# the two copies of word_embeddings start off with the same
# parameter values.
# 3. In the training loop, the grads of the two word_embeddings layers are
# all-reduced to ensure that every applied weight update is the same
# on both stages.
if parallel_state.is_pipeline_last_stage():
assert not parallel_state.is_pipeline_first_stage()
self._word_embeddings_for_head_key = "word_embeddings_for_head"
# set word_embeddings weights to 0 here, then copy first
# stage's weights using all_reduce below.
self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
args.padded_vocab_size, args.hidden_size,
init_method=init_method_normal(args.init_method_std),
use_cpu_initialization=args.use_cpu_initialization)
self.word_embeddings.weight.data.fill_(0)
self.word_embeddings.weight.shared = True
# Zero out initial weights for decoder embedding.
# NOTE: We don't currently support T5 with the interleaved schedule.
if not parallel_state.is_pipeline_first_stage(ignore_virtual=True) and \
not parallel_state.is_pipeline_last_stage(ignore_virtual=True) and \
parallel_state.is_rank_in_embedding_group():
self.language_model.embedding.zero_parameters()
# Ensure that first and last stages have the same initial parameter
# values.
if torch.distributed.is_initialized():
if parallel_state.is_rank_in_embedding_group():
torch.distributed.all_reduce(self.word_embeddings_weight().data,
group=parallel_state.get_embedding_group())
# All-reduce other embeddings as well as necessary. The last stage
# does not have these other embeddings, so just create placeholder
# tensors of the right shape with all zeros.
# NOTE: We don't currently support T5 with the interleaved schedule.
if args.pipeline_model_parallel_split_rank is not None:
# TODO: Support tokentype embedding.
dimensions = (args.max_position_embeddings, args.hidden_size)
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
position_embeddings = torch.nn.Embedding(*dimensions).cuda()
position_embeddings.weight.data.fill_(0)
else:
self.language_model.embedding.cuda()
position_embeddings = self.language_model.embedding.position_embeddings
torch.distributed.all_reduce(position_embeddings.weight.data,
group=parallel_state.get_embedding_group())
else:
print("WARNING! Distributed processes aren't initialized, so "
"word embeddings in the last layer are not initialized. "
"If you are just manipulating a model this is fine, but "
"this needs to be handled manually. If you are training "
"something is definitely wrong.")
class GeLUFunction(torch.autograd.Function):
@staticmethod
# bias is an optional argument
def forward(ctx, input, bias):
ctx.save_for_backward(input, bias)
return bias_gelu(bias, input)
@staticmethod
def backward(ctx, grad_output):
input, bias = ctx.saved_tensors
tmp = bias_gelu_back(grad_output, bias, input)
return tmp, tmp
bias_gelu_impl = GeLUFunction.apply
def get_linear_layer(rows, columns, init_method):
"""Simple linear layer with weight initialization."""
layer = torch.nn.Linear(rows, columns)
init_method(layer.weight)
with torch.no_grad():
layer.bias.zero_()
return layer
def attention_mask_func(attention_scores, attention_mask):
attention_scores.masked_fill_(attention_mask, -10000.0)
return attention_scores
@torch.jit.script
def gelu_impl(x):
"""OpenAI's gelu implementation."""
return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))
def openai_gelu(x):
return gelu_impl(x)
# This is actually Python equivalent of torch.nn.functional.gelu(), also with type hints for ONNX exporter
@torch.jit.script
def erf_gelu(x):
return x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype) + torch.ones_like(x).to(dtype=x.dtype))
def init_method_normal(sigma):
"""Init method based on N(0, sigma)."""
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)
return init_
def scaled_init_method_normal(sigma, num_layers):
"""Init method based on N(0, sigma/sqrt(2*num_layers)."""
std = sigma / math.sqrt(2.0 * num_layers)
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=std)
return init_
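# NOTE (editorial, worked example; not part of the patch): with sigma=0.02 and
# num_layers=24, the scaled init uses std = 0.02 / sqrt(2 * 24) ~= 0.00289, i.e.
# output-layer weights start roughly 7x smaller than with init_method_normal(0.02).
def _scaled_init_example() -> torch.Tensor:
    init_fn = scaled_init_method_normal(0.02, 24)
    w = torch.empty(1024, 1024)
    return init_fn(w)  # w ~ N(0, 0.00289**2)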
class ParallelMLP(MegatronModule):
"""MLP.
MLP will take the input with h hidden state, project it to 4*h
hidden dimension, perform nonlinear transformation, and project the
state back into h hidden dimension.
"""
def __init__(self, init_method, output_layer_init_method):
super().__init__()
args = get_args()
# Project to 4h.
self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
args.hidden_size, args.ffn_hidden_size, gather_output=False, init_method=init_method, skip_bias_add=True,
use_cpu_initialization=args.use_cpu_initialization)
self.bias_gelu_fusion = args.bias_gelu_fusion
self.activation_func = F.gelu
if args.openai_gelu:
self.activation_func = openai_gelu
elif args.onnx_safe:
self.activation_func = erf_gelu
# Project back to h.
self.dense_4h_to_h = tensor_parallel.RowParallelLinear(
args.ffn_hidden_size,
args.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
skip_bias_add=True,
use_cpu_initialization=args.use_cpu_initialization
)
def forward(self, hidden_states):
# [s, b, 4hp]
intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)
if self.bias_gelu_fusion:
intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
else:
intermediate_parallel = self.activation_func(intermediate_parallel + bias_parallel)
# [s, b, h]
output, output_bias = self.dense_4h_to_h(intermediate_parallel)
return output, output_bias
class ParallelAttention(MegatronModule):
"""Parallel self-attention layer abstract class.
Self-attention layer takes input with size [s, b, h]
and returns output of the same size.
"""
def __init__(
self,
init_method,
output_layer_init_method,
layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=AttnMaskType.padding,
):
super().__init__()
args = get_args()
self.fp16 = args.fp16
self.bf16 = args.bf16
self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
if self.apply_query_key_layer_scaling:
self.attention_softmax_in_fp32 = True
self.layer_number = max(1, layer_number)
self.attention_type = attention_type
self.attn_mask_type = attn_mask_type
self.params_dtype = args.params_dtype
projection_size = args.kv_channels * args.num_attention_heads
# Per attention head and per partition values.
world_size = parallel_state.get_tensor_model_parallel_world_size()
self.hidden_size_per_partition = apex.transformer.utils.divide(projection_size, world_size)
self.hidden_size_per_attention_head = apex.transformer.utils.divide(projection_size, args.num_attention_heads)
self.num_attention_heads_per_partition = apex.transformer.utils.divide(args.num_attention_heads, world_size)
# Strided linear layer.
if attention_type == AttnType.self_attn:
self.query_key_value = tensor_parallel.ColumnParallelLinear(
args.hidden_size, 3 * projection_size, gather_output=False, init_method=init_method, use_cpu_initialization=args.use_cpu_initialization)
else:
assert attention_type == AttnType.cross_attn
self.query = tensor_parallel.ColumnParallelLinear(
args.hidden_size, projection_size, gather_output=False, init_method=init_method, use_cpu_initialization=args.use_cpu_initialization)
self.key_value = tensor_parallel.ColumnParallelLinear(
args.hidden_size, 2 * projection_size, gather_output=False, init_method=init_method, use_cpu_initialization=args.use_cpu_initialization)
coeff = None
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
if self.apply_query_key_layer_scaling:
coeff = self.layer_number
self.norm_factor *= coeff
self.scale_mask_softmax = FusedScaleMaskSoftmax(
self.fp16,
self.bf16,
self.attn_mask_type,
args.masked_softmax_fusion,
attention_mask_func,
self.attention_softmax_in_fp32,
coeff,
)
# Dropout. Note that for a single iteration, this layer will generate
# different outputs for different numbers of parallel partitions but
# on average it should not be partition dependent.
self.attention_dropout = torch.nn.Dropout(args.attention_dropout)
# Output.
self.dense = tensor_parallel.RowParallelLinear(
projection_size,
args.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
skip_bias_add=True,
use_cpu_initialization=args.use_cpu_initialization
)
# Inference key-value memory
self.inference_key_memory = None
self.inference_value_memory = None
self.inference_current_sequence_len = 0
def _allocate_memory(self, inference_max_sequence_len, batch_size):
return torch.empty(
inference_max_sequence_len,
batch_size,
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
dtype=self.params_dtype,
device=torch.cuda.current_device(),
)
def forward(
self,
hidden_states,
attention_mask,
encoder_output=None,
set_inference_key_value_memory=False,
inference_max_sequence_len=None,
):
# hidden_states: [sq, b, h]
# =================================================
# Pre-allocate memory for key-values for inference.
# =================================================
if set_inference_key_value_memory:
assert inference_max_sequence_len and inference_max_sequence_len > 0
self.inference_key_memory = self._allocate_memory(inference_max_sequence_len, hidden_states.size(1))
self.inference_value_memory = self._allocate_memory(inference_max_sequence_len, hidden_states.size(1))
self.inference_current_sequence_len = 0
# Some consistency check.
if inference_max_sequence_len:
assert self.inference_current_sequence_len < self.inference_key_memory.size(0)
assert inference_max_sequence_len == self.inference_key_memory.size(0)
# This is added for safety. In case inference_max_sequence_len
# is not provided, make sure there is no potential memory left
# from previous inference.
if not inference_max_sequence_len:
self.inference_key_memory = None
self.inference_value_memory = None
# =====================
# Query, Key, and Value
# =====================
if self.attention_type == AttnType.self_attn:
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
mixed_x_layer, _ = self.query_key_value(hidden_states)
# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + (
self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head,
)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
(query_layer, key_layer, value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_x_layer, 3)
else:
# Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
mixed_kv_layer, _ = self.key_value(encoder_output)
# [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
new_tensor_shape = mixed_kv_layer.size()[:-1] + (
self.num_attention_heads_per_partition,
2 * self.hidden_size_per_attention_head,
)
mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
# [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
(key_layer, value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)
# Attention head [sq, b, h] --> [sq, b, hp]
query_layer, _ = self.query(hidden_states)
# [sq, b, hp] --> [sq, b, np, hn]
new_tensor_shape = query_layer.size()[:-1] + (
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
)
query_layer = query_layer.view(*new_tensor_shape)
# ===================================================
# Adjust key, value, and attention mask for inference
# ===================================================
if inference_max_sequence_len:
# Adjust the range variables.
start = self.inference_current_sequence_len
self.inference_current_sequence_len += key_layer.size(0)
end = self.inference_current_sequence_len
# Copy key and values.
self.inference_key_memory[start:end, ...] = key_layer
self.inference_value_memory[start:end, ...] = value_layer
key_layer = self.inference_key_memory[:end, ...]
value_layer = self.inference_value_memory[:end, ...]
# Adjust attention mask
attention_mask = attention_mask[..., start:end, :end]
# ===================================
# Raw attention scores. [b, np, s, s]
# ===================================
# [b, np, sq, sk]
output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))
# [sq, b, np, hn] -> [sq, b * np, hn]
query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
# [sk, b, np, hn] -> [sk, b * np, hn]
key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
# preallocating result tensor: [b * np, sq, sk]
matmul_result = torch.empty(
output_size[0] * output_size[1],
output_size[2],
output_size[3],
dtype=query_layer.dtype,
device=torch.cuda.current_device(),
)
# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.baddbmm(
matmul_result,
query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
beta=0.0,
alpha=(1.0 / self.norm_factor),
)
# change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size)
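# Shape bookkeeping example (hypothetical sizes): with sq=sk=16, b=2, np=4, hn=64 the
# views above give query [16, 8, 64] and key [16, 8, 64]; baddbmm multiplies
# [8, 16, 64] by [8, 64, 16] into [8, 16, 16], which is reshaped here to
# attention_scores of shape [2, 4, 16, 16] = [b, np, sq, sk].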
# ===========================
# Attention probs and dropout
# ===========================
# attention scores and attention mask [b, np, sq, sk]
attention_probs = self.scale_mask_softmax(attention_scores, attention_mask)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
with tensor_parallel.get_cuda_rng_tracker().fork():
attention_probs = self.attention_dropout(attention_probs)
# =========================
# Context layer. [sq, b, hp]
# =========================
# value_layer -> context layer.
# [sk, b, np, hn] --> [b, np, sq, hn]
# context layer shape: [b, np, sq, hn]
output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))
# change view [sk, b * np, hn]
value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
# change view [b * np, sq, sk]
attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
# matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
# change view [b, np, sq, hn]
context_layer = context_layer.view(*output_size)
# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
# [sq, b, np, hn] --> [sq, b, hp]
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
# =================
# Output. [sq, b, h]
# =================
output, bias = self.dense(context_layer)
return output, bias
@torch.jit.script
def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
out = residual + out
return out
def get_bias_dropout_add(training):
def _bias_dropout_add(x, bias, residual, prob):
return bias_dropout_add(x, bias, residual, prob, training)
return _bias_dropout_add
@torch.jit.script
def bias_dropout_add_fused_train(
x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float
) -> torch.Tensor:
return bias_dropout_add(x, bias, residual, prob, True)
@torch.jit.script
def bias_dropout_add_fused_inference(
x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float
) -> torch.Tensor:
return bias_dropout_add(x, bias, residual, prob, False)
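# Usage sketch: the two @torch.jit.script wrappers above bake the `training` flag in as
# a compile-time constant so the dropout + add can be fused; e.g.
# bias_dropout_add_fused_train(x, bias, residual, 0.1) computes the same result as
# get_bias_dropout_add(True)(x, bias, residual, 0.1), just through the scripted path.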
class ParallelTransformerLayer(MegatronModule):
"""A single transformer layer.
Transformer layer takes input with size [s, b, h] and returns an
output of the same size.
"""
def __init__(
self,
init_method,
output_layer_init_method,
layer_number,
layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding,
):
args = get_args()
super().__init__()
self.layer_number = layer_number
self.layer_type = layer_type
self.apply_residual_connection_post_layernorm = args.apply_residual_connection_post_layernorm
self.bf16 = args.bf16
self.fp32_residual_connection = args.fp32_residual_connection
# Layernorm on the input data.
self.input_layernorm = LayerNorm(args.hidden_size, eps=args.layernorm_epsilon)
# Self attention.
self.self_attention = ParallelAttention(
init_method,
output_layer_init_method,
layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=self_attn_mask_type,
)
self.hidden_dropout = args.hidden_dropout
self.bias_dropout_fusion = args.bias_dropout_fusion
# Layernorm on the attention output
self.post_attention_layernorm = LayerNorm(args.hidden_size, eps=args.layernorm_epsilon)
if self.layer_type == LayerType.decoder:
self.inter_attention = ParallelAttention(
init_method, output_layer_init_method, layer_number, attention_type=AttnType.cross_attn
)
# Layernorm on the attention output.
self.post_inter_attention_layernorm = LayerNorm(args.hidden_size, eps=args.layernorm_epsilon)
# MLP
self.mlp = ParallelMLP(init_method, output_layer_init_method)
def forward(
self,
hidden_states,
attention_mask,
encoder_output=None,
enc_dec_attn_mask=None,
set_inference_key_value_memory=False,
inference_max_sequence_len=None,
):
# hidden_states: [s, b, h]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output, attention_bias = self.self_attention(
layernorm_output,
attention_mask,
set_inference_key_value_memory=set_inference_key_value_memory,
inference_max_sequence_len=inference_max_sequence_len,
)
# Residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
# jit scripting for an nn.Module (with dropout) is not
# triggering the fusion kernel. For now, we use two
# different nn.functional routines to account for varying
# dropout semantics during training and inference phases.
if self.bias_dropout_fusion:
if self.training:
bias_dropout_add_func = bias_dropout_add_fused_train
else:
bias_dropout_add_func = bias_dropout_add_fused_inference
else:
bias_dropout_add_func = get_bias_dropout_add(self.training)
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output, attention_bias.expand_as(residual), residual, self.hidden_dropout
)
# Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input)
if self.layer_type == LayerType.decoder:
attention_output, attention_bias = self.inter_attention(
layernorm_output, enc_dec_attn_mask, encoder_output=encoder_output
)
# residual connection
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output, attention_bias.expand_as(residual), residual, self.hidden_dropout
)
# Layer norm post the decoder attention
layernorm_output = self.post_inter_attention_layernorm(layernorm_input)
# MLP.
mlp_output, mlp_bias = self.mlp(layernorm_output)
# Second residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
output = bias_dropout_add_func(mlp_output, mlp_bias.expand_as(residual), residual, self.hidden_dropout)
return output
class ParallelTransformer(MegatronModule):
"""Transformer class."""
def __init__(
self,
init_method,
output_layer_init_method,
layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding,
pre_process=True,
post_process=True,
):
super().__init__()
args = get_args()
self.bf16 = args.bf16
self.fp32_residual_connection = args.fp32_residual_connection
self.pre_process = pre_process
self.post_process = post_process
self.input_tensor = None
# Store activation checkpointing flag.
self.activations_checkpoint_method = args.activations_checkpoint_method
self.activations_checkpoint_num_layers = args.activations_checkpoint_num_layers
self.distribute_checkpointed_activations = args.distribute_checkpointed_activations
num_layers = args.num_layers
# Number of layers.
assert (
num_layers % parallel_state.get_pipeline_model_parallel_world_size() == 0
), "num_layers must be divisible by pipeline_model_parallel_size"
self.num_layers = num_layers // parallel_state.get_pipeline_model_parallel_world_size()
# Transformer layers.
def build_layer(layer_number):
return ParallelTransformerLayer(
init_method,
output_layer_init_method,
layer_number,
layer_type=layer_type,
self_attn_mask_type=self_attn_mask_type,
)
if args.virtual_pipeline_model_parallel_size is not None:
assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, (
"num_layers_per_stage must be divisible by " "virtual_pipeline_model_parallel_size"
)
# Number of layers in each model chunk is the number of layers in the stage,
# divided by the number of model chunks in a stage.
self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
# With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
# layers to stages like (each list is a model chunk):
# Stage 0: [0] [2] [4] [6]
# Stage 1: [1] [3] [5] [7]
# With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
# layers to stages like (each list is a model chunk):
# Stage 0: [0, 1] [4, 5]
# Stage 1: [2, 3] [6, 7]
offset = parallel_state.get_virtual_pipeline_model_parallel_rank() * (
args.num_layers // args.virtual_pipeline_model_parallel_size
) + (parallel_state.get_pipeline_model_parallel_rank() * self.num_layers)
else:
# Each stage gets a contiguous set of layers.
offset = parallel_state.get_pipeline_model_parallel_rank() * self.num_layers
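# Worked example (hypothetical sizes): with args.num_layers=24 and a pipeline world
# size of 4, each stage owns 24 // 4 = 6 layers; pipeline rank 2 gets offset = 12 and
# builds layer_numbers 13..18. With the interleaved schedule in the branch above
# (8 layers, 2 stages, 2 virtual chunks), offset = virtual_rank * 4 + pipeline_rank * 2,
# which reproduces the chunk assignment listed in the comment.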
self.layers = torch.nn.ModuleList([build_layer(i + 1 + offset) for i in range(self.num_layers)])
if self.post_process:
# Final layer norm before output.
self.final_layernorm = LayerNorm(args.hidden_size, eps=args.layernorm_epsilon)
def _get_layer(self, layer_number):
return self.layers[layer_number]
def _checkpointed_forward(self, hidden_states, attention_mask, encoder_output, enc_dec_attn_mask):
"""Forward method with activation checkpointing."""
def custom(start, end):
def custom_forward(*inputs):
x_ = inputs[0]
attention_mask = inputs[1]
encoder_output = inputs[2]
enc_dec_attn_mask = inputs[3]
for index in range(start, end):
layer = self._get_layer(index)
x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
return x_
return custom_forward
def distribute_checkpointed_activations_helper(layer_number):
"""Distribute checkpointed activations across the tensor model
Parallel ranks if the `distribute-checkpointed-activations
is on and either of the following conditions is met:
- it is not the first layer in the in the pipeline stage.
The first layer is used in the pipeline parallelism
and changing its shape throws error in the backward pass.
- we are at the first pipline stage so the input tensor is
not used in pipeline parallelism. Note that no pipeline
parallelism is a special case of this.
"""
not_first_layer_in_pipeline_stage = layer_number > 0
is_first_pipeline_stage = parallel_state.get_pipeline_model_parallel_rank() == 0
return self.distribute_checkpointed_activations and (
not_first_layer_in_pipeline_stage or is_first_pipeline_stage
)
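# For example, with distribute_checkpointed_activations=True: (layer 0, pipeline rank 0)
# -> True, (layer 0, pipeline rank 1) -> False, (layer 2, any rank) -> True; with the
# flag off the helper always returns False.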
if self.activations_checkpoint_method == "uniform":
# Uniformly divide the total number of Transformer layers and checkpoint
# the input activation of each divided chunk.
# This further reduces memory usage by keeping fewer activation checkpoints.
l = 0
while l < self.num_layers:
hidden_states = tensor_parallel.checkpoint(
custom(l, l + self.activations_checkpoint_num_layers),
distribute_checkpointed_activations_helper(l),
hidden_states,
attention_mask,
encoder_output,
enc_dec_attn_mask,
)
l += self.activations_checkpoint_num_layers
elif self.activations_checkpoint_method == "block":
# Checkpoint the input activation of only a set number of individual
# Transformer layers and skip the rest.
# This makes fuller use of device memory while removing redundant re-computation.
for l in range(self.num_layers):
if l < self.activations_checkpoint_num_layers:
hidden_states = tensor_parallel.checkpoint(
custom(l, l + 1),
distribute_checkpointed_activations_helper(l),
hidden_states,
attention_mask,
encoder_output,
enc_dec_attn_mask,
)
else:
hidden_states = custom(l, l + 1)(hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
else:
raise ValueError("Invalid activation checkpoint method.")
return hidden_states
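# Illustration (hypothetical sizes): with 8 layers on this stage and
# activations_checkpoint_num_layers=2, "uniform" issues 4 checkpoint calls covering
# chunks [0-1], [2-3], [4-5], [6-7], while "block" checkpoints layers 0 and 1
# individually and runs layers 2..7 without checkpointing.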
def set_input_tensor(self, input_tensor):
"""Set input tensor to be used instead of forward()'s input.
When doing pipeline parallelism the input from the previous
stage comes from communication, not from the input, so the
model's forward_step_func won't have it. This function is thus
used by internal code to bypass the input provided by the
forward_step_func"""
self.input_tensor = input_tensor
def forward(
self,
hidden_states,
attention_mask,
encoder_output=None,
enc_dec_attn_mask=None,
set_inference_key_value_memory=False,
inference_max_sequence_len=None,
):
# Checks.
if inference_max_sequence_len:
assert self.activations_checkpoint_method is None, "inference does not work with activation checkpointing"
if self.pre_process:
# Data format change to avoid explicit transposes: [b s h] --> [s b h].
# If the input flag for fp32 residual connection is set, convert to float.
if self.fp32_residual_connection:
hidden_states = hidden_states.transpose(0, 1).contiguous().float()
# Otherwise, leave it as is.
else:
hidden_states = hidden_states.transpose(0, 1).contiguous()
else:
# See set_input_tensor()
hidden_states = self.input_tensor
if encoder_output is not None:
encoder_output = encoder_output.transpose(0, 1).contiguous()
if self.activations_checkpoint_method is not None:
hidden_states = self._checkpointed_forward(hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
else:
for index in range(self.num_layers):
layer = self._get_layer(index)
hidden_states = layer(
hidden_states,
attention_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
set_inference_key_value_memory=set_inference_key_value_memory,
inference_max_sequence_len=inference_max_sequence_len,
)
# Final layer norm.
if self.post_process:
# Reverting data format change [s b h] --> [b s h].
hidden_states = hidden_states.transpose(0, 1).contiguous()
output = self.final_layernorm(hidden_states)
else:
output = hidden_states
return output
def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None):
"""LM logits using word embedding weights."""
# Parallel logits.
input_parallel = tensor_parallel.copy_to_tensor_model_parallel_region(input_)
# Matrix multiply.
if bias is None:
logits_parallel = F.linear(input_parallel, word_embeddings_weight)
else:
logits_parallel = F.linear(input_parallel, word_embeddings_weight, bias)
# Gather if needed.
if parallel_output:
return logits_parallel
return tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel)
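# For example (hypothetical sizes), with tensor parallel size 2 and a padded vocab of
# 50304, each rank's word_embeddings_weight shard is [25152, h], so logits_parallel is
# [..., 25152] per rank; parallel_output=True hands that shard straight to
# vocab_parallel_cross_entropy, otherwise the gather returns the full [..., 50304].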
def get_language_model(
num_tokentypes,
add_pooler,
encoder_attn_mask_type,
init_method=None,
scaled_init_method=None,
add_encoder=True,
add_decoder=False,
decoder_attn_mask_type=AttnMaskType.causal,
pre_process=True,
post_process=True,
):
"""Build language model and return along with the key to save."""
args = get_args()
if init_method is None:
init_method = init_method_normal(args.init_method_std)
if scaled_init_method is None:
scaled_init_method = scaled_init_method_normal(args.init_method_std, args.num_layers)
# Language model.
language_model = TransformerLanguageModel(
init_method,
scaled_init_method,
encoder_attn_mask_type,
num_tokentypes=num_tokentypes,
add_encoder=add_encoder,
add_decoder=add_decoder,
decoder_attn_mask_type=decoder_attn_mask_type,
add_pooler=add_pooler,
pre_process=pre_process,
post_process=post_process,
)
# key used for checkpoints.
language_model_key = "language_model"
return language_model, language_model_key
class Pooler(MegatronModule):
"""Pooler layer.
Pool hidden states of a specific token (for example start of the
sequence) and add a linear transformation followed by a tanh.
Arguments:
hidden_size: hidden size
init_method: weight initialization method for the linear layer.
bias is set to zero.
"""
def __init__(self, hidden_size, init_method):
super(Pooler, self).__init__()
self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
def forward(self, hidden_states, sequence_index=0):
# hidden_states: [b, s, h]
# sequence_index: index of the token to pool.
pooled = hidden_states[:, sequence_index, :]
pooled = self.dense(pooled)
pooled = torch.tanh(pooled)
return pooled
class Embedding(MegatronModule):
"""Language model embeddings.
Arguments:
hidden_size: hidden size
vocab_size: vocabulary size
max_sequence_length: maximum size of sequence. This
is used for positional embedding
embedding_dropout_prob: dropout probability for embeddings
init_method: weight initialization method
num_tokentypes: size of the token-type embeddings. 0 value
will ignore this embedding
"""
def __init__(
self, hidden_size, vocab_size, max_sequence_length, embedding_dropout_prob, init_method, num_tokentypes=0
):
super().__init__()
self.hidden_size = hidden_size
self.init_method = init_method
self.num_tokentypes = num_tokentypes
args = get_args()
# Word embeddings (parallel).
self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
vocab_size, self.hidden_size, init_method=self.init_method,
use_cpu_initialization=args.use_cpu_initialization
)
self._word_embeddings_key = "word_embeddings"
# Position embedding (serial).
self.position_embeddings = torch.nn.Embedding(max_sequence_length, self.hidden_size)
self._position_embeddings_key = "position_embeddings"
# Initialize the position embeddings.
self.init_method(self.position_embeddings.weight)
# Token type embedding.
# Add this as an optional field that can be added through
# method call so we can load a pretrain model without
# token types and add them as needed.
self._tokentype_embeddings_key = "tokentype_embeddings"
if self.num_tokentypes > 0:
self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.hidden_size)
# Initialize the token-type embeddings.
self.init_method(self.tokentype_embeddings.weight)
else:
self.tokentype_embeddings = None
# Embeddings dropout
self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
if torch.distributed.get_rank() == 0:
print("FINISH WORD EMBEDDING", self.word_embeddings)
def zero_parameters(self):
"""Zero out all parameters in embedding."""
self.word_embeddings.weight.data.fill_(0)
self.word_embeddings.weight.shared = True
self.position_embeddings.weight.data.fill_(0)
self.position_embeddings.weight.shared = True
if self.num_tokentypes > 0:
self.tokentype_embeddings.weight.data.fill_(0)
self.tokentype_embeddings.weight.shared = True
def add_tokentype_embeddings(self, num_tokentypes):
"""Add token-type embedding. This function is provided so we can add
token-type embeddings in case the pretrained model does not have it.
This allows us to load the model normally and then add this embedding.
"""
if self.tokentype_embeddings is not None:
raise Exception("tokentype embeddings is already initialized")
if torch.distributed.get_rank() == 0:
print("adding embedding for {} tokentypes".format(num_tokentypes), flush=True)
self.num_tokentypes = num_tokentypes
self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, self.hidden_size)
# Initialize the token-type embeddings.
self.init_method(self.tokentype_embeddings.weight)
def forward(self, input_ids, position_ids, tokentype_ids=None):
# Embeddings.
words_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
embeddings = words_embeddings + position_embeddings
if tokentype_ids is not None:
assert self.tokentype_embeddings is not None
embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
else:
assert self.tokentype_embeddings is None
# Dropout.
embeddings = self.embedding_dropout(embeddings)
return embeddings
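# Shape sketch (hypothetical sizes): input_ids and position_ids of shape [b, s] (e.g.
# [4, 1024]) produce embeddings of shape [b, s, h]; this module stays batch-first, and
# ParallelTransformer transposes to [s, b, h] later in its forward pass.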
def state_dict_for_save_checkpoint(self, destination=None, prefix="", keep_vars=False):
"""For easy load."""
state_dict_ = {}
state_dict_[self._word_embeddings_key] = self.word_embeddings.state_dict(destination, prefix, keep_vars)
state_dict_[self._position_embeddings_key] = self.position_embeddings.state_dict(destination, prefix, keep_vars)
if self.num_tokentypes > 0:
state_dict_[self._tokentype_embeddings_key] = self.tokentype_embeddings.state_dict(
destination, prefix, keep_vars
)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
# Word embedding.
if self._word_embeddings_key in state_dict:
state_dict_ = state_dict[self._word_embeddings_key]
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if "word_embeddings" in key:
state_dict_[key.split("word_embeddings.")[1]] = state_dict[key]
self.word_embeddings.load_state_dict(state_dict_, strict=strict)
# Position embedding.
if self._position_embeddings_key in state_dict:
state_dict_ = state_dict[self._position_embeddings_key]
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if "position_embeddings" in key:
state_dict_[key.split("position_embeddings.")[1]] = state_dict[key]
self.position_embeddings.load_state_dict(state_dict_, strict=strict)
# Tokentype embedding.
if self.num_tokentypes > 0:
state_dict_ = {}
if self._tokentype_embeddings_key in state_dict:
state_dict_ = state_dict[self._tokentype_embeddings_key]
else:
# for backward compatibility.
for key in state_dict.keys():
if "tokentype_embeddings" in key:
state_dict_[key.split("tokentype_embeddings.")[1]] = state_dict[key]
if len(state_dict_.keys()) > 0:
self.tokentype_embeddings.load_state_dict(state_dict_, strict=strict)
else:
print(
"***WARNING*** expected tokentype embeddings in the " "checkpoint but could not find it", flush=True
)
class TransformerLanguageModel(MegatronModule):
"""Transformer language model.
Arguments:
transformer_hparams: transformer hyperparameters
vocab_size: vocabulary size
max_sequence_length: maximum size of sequence. This
is used for positional embedding
embedding_dropout_prob: dropout probability for embeddings
num_tokentypes: size of the token-type embeddings. 0 value
will ignore this embedding
"""
def __init__(
self,
init_method,
output_layer_init_method,
encoder_attn_mask_type,
num_tokentypes=0,
add_encoder=True,
add_decoder=False,
decoder_attn_mask_type=AttnMaskType.causal,
add_pooler=False,
pre_process=True,
post_process=True,
):
super().__init__()
args = get_args()
self.pre_process = pre_process
self.post_process = post_process
self.hidden_size = args.hidden_size
self.num_tokentypes = num_tokentypes
self.init_method = init_method
self.add_encoder = add_encoder
self.encoder_attn_mask_type = encoder_attn_mask_type
self.add_decoder = add_decoder
self.decoder_attn_mask_type = decoder_attn_mask_type
self.add_pooler = add_pooler
self.encoder_hidden_state = None
# Embeddings.
if self.pre_process:
self.embedding = Embedding(
self.hidden_size,
args.padded_vocab_size,
args.max_position_embeddings,
args.hidden_dropout,
self.init_method,
self.num_tokentypes,
)
self._embedding_key = "embedding"
# Transformer.
# Encoder (usually set to True, False if part of an encoder-decoder
# architecture and in encoder-only stage).
if self.add_encoder:
self.encoder = ParallelTransformer(
self.init_method,
output_layer_init_method,
self_attn_mask_type=self.encoder_attn_mask_type,
pre_process=self.pre_process,
post_process=self.post_process,
)
self._encoder_key = "encoder"
else:
self.encoder = None
# Decoder (usually set to False, True if part of an encoder-decoder
# architecture and in decoder-only stage).
if self.add_decoder:
# Temporary assertion until we verify correctness of pipeline parallelism
# implementation of T5.
assert (
args.pipeline_model_parallel_size == 1
), "pipeline parallelism is not supported in the presence of decoder"
self.decoder = ParallelTransformer(
self.init_method,
output_layer_init_method,
layer_type=LayerType.decoder,
self_attn_mask_type=self.decoder_attn_mask_type,
pre_process=self.pre_process,
post_process=self.post_process,
)
self._decoder_key = "decoder"
else:
self.decoder = None
if self.post_process:
# Pooler.
if self.add_pooler:
self.pooler = Pooler(self.hidden_size, self.init_method)
self._pooler_key = "pooler"
def set_input_tensor(self, input_tensor):
""" See megatron.model.transformer.set_input_tensor()"""
# This is usually handled in schedules.py but some inference code still
# gives us non-lists or None
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
if self.add_encoder and self.add_decoder:
assert (
len(input_tensor) == 1
), "input_tensor should only be length 1 for stage with both encoder and decoder"
self.encoder.set_input_tensor(input_tensor[0])
elif self.add_encoder:
assert len(input_tensor) == 1, "input_tensor should only be length 1 for stage with only encoder"
self.encoder.set_input_tensor(input_tensor[0])
elif self.add_decoder:
if len(input_tensor) == 2:
self.decoder.set_input_tensor(input_tensor[0])
self.encoder_hidden_state = input_tensor[1]
elif len(input_tensor) == 1:
self.decoder.set_input_tensor(None)
self.encoder_hidden_state = input_tensor[0]
else:
raise Exception("input_tensor must have either length 1 or 2")
else:
raise Exception("Stage must have at least either encoder or decoder")
def forward(
self,
enc_input_ids,
enc_position_ids,
enc_attn_mask,
dec_input_ids=None,
dec_position_ids=None,
dec_attn_mask=None,
enc_dec_attn_mask=None,
tokentype_ids=None,
set_inference_key_value_memory=False,
inference_max_sequence_len=None,
pooling_sequence_index=0,
enc_hidden_states=None,
output_enc_hidden=False,
):
# Encoder embedding.
if self.pre_process:
encoder_input = self.embedding(enc_input_ids, enc_position_ids, tokentype_ids=tokentype_ids)
else:
encoder_input = None
# Run encoder.
if enc_hidden_states is None:
if self.encoder is not None:
encoder_output = self.encoder(
encoder_input,
enc_attn_mask,
set_inference_key_value_memory=set_inference_key_value_memory,
inference_max_sequence_len=inference_max_sequence_len,
)
else:
encoder_output = self.encoder_hidden_state
else:
encoder_output = enc_hidden_states.to(encoder_input.dtype)
if self.post_process:
if self.add_pooler:
pooled_output = self.pooler(encoder_output, pooling_sequence_index)
# output_enc_hidden refers to when we just need the encoder's
# output. For example, it is helpful to compute
# similarity between two sequences by average pooling
if not self.add_decoder or output_enc_hidden:
if self.add_pooler and self.post_process:
return encoder_output, pooled_output
else:
return encoder_output
# Decoder embedding.
if self.pre_process:
decoder_input = self.embedding(dec_input_ids, dec_position_ids)
else:
decoder_input = None
# Run decoder.
decoder_output = self.decoder(
decoder_input,
dec_attn_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
set_inference_key_value_memory=set_inference_key_value_memory,
inference_max_sequence_len=inference_max_sequence_len,
)
if self.add_pooler and self.post_process:
return decoder_output, encoder_output, pooled_output
else:
return decoder_output, encoder_output
def state_dict_for_save_checkpoint(self, destination=None, prefix="", keep_vars=False):
"""For easy load."""
state_dict_ = {}
if self.pre_process:
state_dict_[self._embedding_key] = self.embedding.state_dict_for_save_checkpoint(
destination, prefix, keep_vars
)
if self.add_encoder:
state_dict_[self._encoder_key] = self.encoder.state_dict_for_save_checkpoint(destination, prefix, keep_vars)
if self.post_process:
if self.add_pooler:
state_dict_[self._pooler_key] = self.pooler.state_dict_for_save_checkpoint(
destination, prefix, keep_vars
)
if self.add_decoder:
state_dict_[self._decoder_key] = self.decoder.state_dict_for_save_checkpoint(destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
# Embedding.
if self.pre_process:
if self._embedding_key in state_dict:
state_dict_ = state_dict[self._embedding_key]
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if "_embeddings" in key:
state_dict_[key] = state_dict[key]
self.embedding.load_state_dict(state_dict_, strict=strict)
# Encoder.
if self.add_encoder:
if self._encoder_key in state_dict:
state_dict_ = state_dict[self._encoder_key]
# For backward compatibility.
elif "transformer" in state_dict:
state_dict_ = state_dict["transformer"]
else:
# For backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if "transformer." in key:
state_dict_[key.split("transformer.")[1]] = state_dict[key]
# For backward compatibility.
state_dict_self_attention = {}
for key in state_dict_.keys():
if ".attention." in key:
state_dict_self_attention[key.replace(".attention.", ".self_attention.")] = state_dict_[key]
else:
state_dict_self_attention[key] = state_dict_[key]
state_dict_ = state_dict_self_attention
self.encoder.load_state_dict(state_dict_, strict=strict)
# Pooler.
if self.post_process:
if self.add_pooler:
assert "pooler" in state_dict, "could not find data for pooler in the checkpoint"
self.pooler.load_state_dict(state_dict[self._pooler_key], strict=strict)
# Decoder.
if self.add_decoder:
assert "decoder" in state_dict, "could not find data for pooler in the checkpoint"
self.decoder.load_state_dict(state_dict[self._decoder_key], strict=strict)
def post_language_model_processing(lm_output, labels, logit_weights, parallel_output, fp16_lm_cross_entropy):
# Output.
output = parallel_lm_logits(lm_output, logit_weights, parallel_output)
if labels is None:
return output
else:
if fp16_lm_cross_entropy:
assert output.dtype == torch.half
loss = tensor_parallel.vocab_parallel_cross_entropy(output, labels)
else:
loss = tensor_parallel.vocab_parallel_cross_entropy(output.float(), labels)
return loss
def module_size(m: torch.nn.Module, only_trainable: bool = False):
"""
returns the total number of parameters used by `m` (only counting
shared parameters once); if `only_trainable` is True, then only
includes parameters with `requires_grad = True`
"""
parameters = list(m.parameters())
if only_trainable:
parameters = [p for p in parameters if p.requires_grad]
unique = {p.data_ptr(): p for p in parameters}.values()
return sum(p.numel() for p in unique)
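# Usage note: because parameters are de-duplicated by data_ptr(), tied weights (e.g. the
# shared word embeddings) are counted once; module_size(model, only_trainable=True)
# additionally skips parameters with requires_grad=False.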
class GPTModel(MegatronModule):
"""GPT-2 Language model."""
def __init__(self, num_tokentypes=0, parallel_output=True, pre_process=True, post_process=True, cpu_offload=False):
super(GPTModel, self).__init__()
args = get_args()
self.parallel_output = parallel_output
self.pre_process = pre_process
self.post_process = post_process
self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
self.cpu_offload = cpu_offload
self.language_model, self._language_model_key = get_language_model(
num_tokentypes=num_tokentypes,
add_pooler=False,
encoder_attn_mask_type=AttnMaskType.causal,
init_method=init_method_normal(args.init_method_std),
scaled_init_method=scaled_init_method_normal(args.init_method_std, args.num_layers),
pre_process=self.pre_process,
post_process=self.post_process,
)
def forward(
self,
input_ids,
position_ids,
attention_mask,
labels=None,
tokentype_ids=None,
set_inference_key_value_memory=False,
inference_max_sequence_len=None,
inference_params=None,
):
with torch.autograd.graph.save_on_cpu() if self.cpu_offload else contextlib.nullcontext():
lm_output = self.language_model(
input_ids,
position_ids,
attention_mask,
set_inference_key_value_memory=set_inference_key_value_memory,
inference_max_sequence_len=inference_max_sequence_len,
)
if self.post_process:
return post_language_model_processing(
lm_output,
# note(mkozuki): Am I overlooking some order of dim change?
labels.t().contiguous(),
self.word_embeddings_weight(),
self.parallel_output,
self.fp16_lm_cross_entropy,
)
else:
return lm_output
def state_dict_for_save_checkpoint(self, destination=None, prefix="", keep_vars=False):
state_dict_ = {}
state_dict_[self._language_model_key] = self.language_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars
)
# Save word_embeddings.
if self.post_process and not self.pre_process:
state_dict_[self._word_embeddings_for_head_key] = self.word_embeddings.state_dict(
destination, prefix, keep_vars
)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
# Load word_embeddings.
if self.post_process and not self.pre_process:
self.word_embeddings.load_state_dict(state_dict[self._word_embeddings_for_head_key], strict=strict)
if self._language_model_key in state_dict:
state_dict = state_dict[self._language_model_key]
self.language_model.load_state_dict(state_dict, strict=strict)
def gpt_model_provider(pre_process=True, post_process=False, cpu_offload=False):
model = GPTModel(num_tokentypes=0, parallel_output=True, pre_process=pre_process, post_process=post_process, cpu_offload=cpu_offload)
if torch.distributed.get_rank() == 0:
init_dict = {"pre_process":pre_process, "post_process":post_process, "cpu_offload":cpu_offload}
print("Initialized GPT-2 w/:", json.dumps(init_dict))
n_params = module_size(model) * parallel_state.get_tensor_model_parallel_world_size() * parallel_state.get_pipeline_model_parallel_world_size()
print("Number of Parameters:", n_params)
return model
# coding=utf-8
# Copyright (c) 2021-22, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GPT-2 model."""
import enum
import math
import contextlib
import json
import torch
import torch.nn.functional as F
import apex.transformer.utils
from apex.transformer.layers import FusedLayerNorm as LayerNorm
from apex.transformer.functional import FusedScaleMaskSoftmax
from apex.transformer import tensor_parallel
from apex.transformer.tensor_parallel.layers import ColumnParallelLinear
from apex.transformer.tensor_parallel.layers import RowParallelLinear
from apex.transformer.tensor_parallel.layers import VocabParallelEmbedding
from apex.transformer.tensor_parallel.mappings import scatter_to_sequence_parallel_region
from apex.transformer import parallel_state
from apex.transformer.testing.global_vars import get_args
from apex.transformer.enums import ModelType
from apex.transformer.enums import LayerType
from apex.transformer.enums import AttnType
from apex.transformer.enums import AttnMaskType
from apex.transformer.log_util import get_transformer_logger
_logger = get_transformer_logger(__name__)
def param_is_not_shared(param: torch.Tensor) -> bool:
return not getattr(param, "shared", False)
class MegatronModule(torch.nn.Module):
"""Megatron specific extensions of torch Module with support for pipelining."""
def __init__(self, share_word_embeddings: bool = True) -> None:
super().__init__()
self.share_word_embeddings = share_word_embeddings
def word_embeddings_weight(self):
if self.pre_process:
return self.language_model.embedding.word_embeddings.weight
else:
if not self.share_word_embeddings:
raise Exception('word_embeddings_weight() called for last stage, but share_word_embeddings is false')
return self.word_embeddings.weight
def initialize_word_embeddings(self, init_method_normal):
args = get_args()
if not self.share_word_embeddings:
raise Exception("initialize_word_embeddings() was called but share_word_embeddings is false")
# This function just initializes the word embeddings in the final stage
# when we are using pipeline parallelism. Nothing to do if we aren't
# using pipeline parallelism.
if args.pipeline_model_parallel_size == 1:
return
# Parameters are shared between the word embeddings layers, and the
# heads at the end of the model. In a pipelined setup with more than
# one stage, the initial embedding layer and the head are on different
# workers, so we do the following:
# 1. Create a second copy of word_embeddings on the last stage, with
# initial parameters of 0.0.
# 2. Do an all-reduce between the first and last stage to ensure that
# the two copies of word_embeddings start off with the same
# parameter values.
# 3. In the training loop, before the optimizer step, all-reduce the
# grads of the two word_embeddings layers to ensure that every applied
# weight update is the same on both stages.
if parallel_state.is_pipeline_last_stage() and not self.pre_process:
assert not parallel_state.is_pipeline_first_stage()
self._word_embeddings_for_head_key = 'word_embeddings_for_head'
# set word_embeddings weights to 0 here, then copy first
# stage's weights using all_reduce below.
self.word_embeddings = VocabParallelEmbedding(
args.padded_vocab_size, args.hidden_size,
init_method=init_method_normal(args.init_method_std))
self.word_embeddings.weight.data.fill_(0)
self.word_embeddings.weight.shared = True
# Zero out initial weights for decoder embedding.
# NOTE: We don't currently support T5 with the interleaved schedule.
if not parallel_state.is_pipeline_first_stage(ignore_virtual=True) and self.pre_process:
self.language_model.embedding.zero_parameters()
# Ensure that first and last stages have the same initial parameter
# values.
if torch.distributed.is_initialized():
if parallel_state.is_rank_in_embedding_group():
torch.distributed.all_reduce(self.word_embeddings_weight(),
group=parallel_state.get_embedding_group())
# Ensure that encoder (first stage) and decoder (split stage) position
# embeddings have the same initial parameter values
# NOTE: We don't currently support T5 with the interleaved schedule.
if parallel_state.is_rank_in_position_embedding_group() and \
args.pipeline_model_parallel_split_rank is not None:
# TODO: Support tokentype embedding.
self.language_model.embedding.cuda()
position_embeddings = self.language_model.embedding.position_embeddings
torch.distributed.all_reduce(position_embeddings.weight,
group=parallel_state.get_position_embedding_group())
else:
print("WARNING! Distributed processes aren't initialized, so "
"word embeddings in the last layer are not initialized. "
"If you are just manipulating a model this is fine, but "
"this needs to be handled manually. If you are training "
"something is definitely wrong.")
def get_linear_layer(rows, columns, init_method):
"""Simple linear layer with weight initialization."""
layer = torch.nn.Linear(rows, columns)
init_method(layer.weight)
with torch.no_grad():
layer.bias.zero_()
return layer
# NOTE(mkozuki): Avoid inplace op.
def attention_mask_func(attention_scores: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
# attention_scores.masked_fill_(attention_mask, -10000.0)
# return attention_scores
return attention_scores.masked_fill(attention_mask, -10000.0)
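# For example, a boolean mask that is True above the diagonal turns this into causal
# masking: masked positions are filled with -10000.0 and receive ~zero probability
# after the softmax, while the out-of-place masked_fill leaves the input scores intact.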
def init_method_normal(sigma):
"""Init method based on N(0, sigma)."""
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)
return init_
def scaled_init_method_normal(sigma, num_layers):
"""Init method based on N(0, sigma/sqrt(2*num_layers)."""
std = sigma / math.sqrt(2.0 * num_layers)
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=std)
return init_
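# Worked example (hypothetical values): sigma=0.02 and num_layers=24 give
# std = 0.02 / sqrt(48) ~= 0.0029 for the output-layer initialization.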
class ParallelMLP(MegatronModule):
"""MLP.
MLP will take the input with h hidden state, project it to 4*h
hidden dimension, perform nonlinear transformation, and project the
state back into h hidden dimension.
"""
def __init__(self, init_method, output_layer_init_method):
super().__init__()
args = get_args()
# Project to 4h.
self.dense_h_to_4h = ColumnParallelLinear(
args.hidden_size,
args.ffn_hidden_size,
gather_output=False,
init_method=init_method,
skip_bias_add=True,
no_async_tensor_model_parallel_allreduce=not args.async_tensor_model_parallel_allreduce,
sequence_parallel_enabled=args.sequence_parallel,
)
self.bias_gelu_fusion = args.bias_gelu_fusion
self.activation_func = F.gelu
# Project back to h.
self.dense_4h_to_h = RowParallelLinear(
args.ffn_hidden_size,
args.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
skip_bias_add=True,
sequence_parallel_enabled=args.sequence_parallel,
)
def forward(self, hidden_states):
# [s, b, 4hp]
intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)
intermediate_parallel = self.activation_func(intermediate_parallel + bias_parallel)
# [s, b, h]
output, output_bias = self.dense_4h_to_h(intermediate_parallel)
return output, output_bias
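# Shape sketch (hypothetical sizes): with hidden_size=1024, ffn_hidden_size=4096 and
# tensor parallel size 2, dense_h_to_4h returns a [s, b, 2048] shard per rank (its bias
# is added manually above because skip_bias_add=True), and dense_4h_to_h maps back to
# [s, b, 1024], returning its bias unapplied for the fused bias-dropout-add later.
# When args.sequence_parallel is set, the column-parallel input is all-gathered along
# the sequence dimension and the row-parallel output is reduce-scattered back, which is
# how the sequence parallelism ported here keeps activations sharded over the sequence.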
class CoreAttention(MegatronModule):
def __init__(self, layer_number, attn_mask_type=AttnMaskType.padding):
super().__init__()
args = get_args()
self.fp16 = args.fp16
self.bf16 = args.bf16
self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
if self.apply_query_key_layer_scaling:
self.attention_softmax_in_fp32 = True
self.layer_number = max(1, layer_number)
self.attn_mask_type = attn_mask_type
self.sequence_parallel = args.sequence_parallel
projection_size = args.kv_channels * args.num_attention_heads
# Per attention head and per partition values.
world_size = parallel_state.get_tensor_model_parallel_world_size()
self.hidden_size_per_partition = apex.transformer.utils.divide(
projection_size, world_size
)
self.hidden_size_per_attention_head = apex.transformer.utils.divide(
projection_size, args.num_attention_heads
)
self.num_attention_heads_per_partition = apex.transformer.utils.divide(
args.num_attention_heads, world_size
)
coeff = None
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
if self.apply_query_key_layer_scaling:
coeff = self.layer_number
self.norm_factor *= coeff
self.scale_mask_softmax = FusedScaleMaskSoftmax(
self.fp16,
self.bf16,
self.attn_mask_type,
args.masked_softmax_fusion,
attention_mask_func,
self.attention_softmax_in_fp32,
coeff,
)
# Dropout. Note that for a single iteration, this layer will generate
# different outputs on different numbers of parallel partitions, but
# on average it should not be partition dependent.
self.attention_dropout = torch.nn.Dropout(args.attention_dropout)
def forward(self, query_layer, key_layer, value_layer, attention_mask):
# ===================================
# Raw attention scores. [b, np, s, s]
# ===================================
# [b, np, sq, sk]
output_size = (
query_layer.size(1),
query_layer.size(2),
query_layer.size(0),
key_layer.size(0),
)
# [sq, b, np, hn] -> [sq, b * np, hn]
query_layer = query_layer.view(
output_size[2], output_size[0] * output_size[1], -1
)
# [sk, b, np, hn] -> [sk, b * np, hn]
key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
# preallocating input tensor: [b * np, sq, sk]
matmul_input_buffer = torch.empty(
output_size[0] * output_size[1],
output_size[2],
output_size[3],
dtype=query_layer.dtype,
device=torch.cuda.current_device(),
)
# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.baddbmm(
matmul_input_buffer,
query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
beta=0.0,
alpha=(1.0 / self.norm_factor),
)
# change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size)
# ===========================
# Attention probs and dropout
# ===========================
# attention scores and attention mask [b, np, sq, sk]
attention_probs = self.scale_mask_softmax(attention_scores, attention_mask)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
if not self.sequence_parallel:
with tensor_parallel.get_cuda_rng_tracker().fork():
attention_probs = self.attention_dropout(attention_probs)
else:
attention_probs = self.attention_dropout(attention_probs)
# =========================
# Context layer. [sq, b, hp]
# =========================
# value_layer -> context layer.
# [sk, b, np, hn] --> [b, np, sq, hn]
# context layer shape: [b, np, sq, hn]
output_size = (
value_layer.size(1),
value_layer.size(2),
query_layer.size(0),
value_layer.size(3),
)
# change view [sk, b * np, hn]
value_layer = value_layer.view(
value_layer.size(0), output_size[0] * output_size[1], -1
)
# change view [b * np, sq, sk]
attention_probs = attention_probs.view(
output_size[0] * output_size[1], output_size[2], -1
)
# matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
# change view [b, np, sq, hn]
context_layer = context_layer.view(*output_size)
# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
# [sq, b, np, hn] --> [sq, b, hp]
new_context_layer_shape = context_layer.size()[:-2] + (
self.hidden_size_per_partition,
)
context_layer = context_layer.view(*new_context_layer_shape)
return context_layer
class ParallelAttention(MegatronModule):
"""Parallel self-attention layer abstract class.
Self-attention layer takes input with size [b, s, h]
and returns output of the same size.
"""
def __init__(
self,
init_method,
output_layer_init_method,
layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=AttnMaskType.padding,
):
super().__init__()
args = get_args()
self.layer_number = max(1, layer_number)
self.attention_type = attention_type
self.attn_mask_type = attn_mask_type
self.params_dtype = args.params_dtype
projection_size = args.kv_channels * args.num_attention_heads
# Per attention head and per partition values.
world_size = parallel_state.get_tensor_model_parallel_world_size()
self.hidden_size_per_attention_head = apex.transformer.utils.divide(
projection_size, args.num_attention_heads
)
self.num_attention_heads_per_partition = apex.transformer.utils.divide(
args.num_attention_heads, world_size
)
# Strided linear layer.
if attention_type == AttnType.self_attn:
self.query_key_value = ColumnParallelLinear(
args.hidden_size,
3 * projection_size,
gather_output=False,
init_method=init_method,
no_async_tensor_model_parallel_allreduce=not args.async_tensor_model_parallel_allreduce,
sequence_parallel_enabled=args.sequence_parallel,
)
else:
assert attention_type == AttnType.cross_attn
self.query = ColumnParallelLinear(
args.hidden_size,
projection_size,
gather_output=False,
init_method=init_method,
no_async_tensor_model_parallel_allreduce=not args.async_tensor_model_parallel_allreduce,
sequence_parallel_enabled=args.sequence_parallel,
)
self.key_value = ColumnParallelLinear(
args.hidden_size,
2 * projection_size,
gather_output=False,
init_method=init_method,
no_async_tensor_model_parallel_allreduce=not args.async_tensor_model_parallel_allreduce,
sequence_parallel_enabled=args.sequence_parallel,
)
self.core_attention = CoreAttention(self.layer_number, self.attn_mask_type)
self.checkpoint_core_attention = args.recompute_granularity == "selective"
# Output.
self.dense = RowParallelLinear(
projection_size,
args.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
skip_bias_add=True,
sequence_parallel_enabled=args.sequence_parallel,
)
def _checkpointed_attention_forward(
self, query_layer, key_layer, value_layer, attention_mask
):
"""Forward method with activation checkpointing."""
def custom_forward(*inputs):
query_layer = inputs[0]
key_layer = inputs[1]
value_layer = inputs[2]
attention_mask = inputs[3]
output_ = self.core_attention(
query_layer, key_layer, value_layer, attention_mask
)
return output_
hidden_states = tensor_parallel.checkpoint(
custom_forward, False, query_layer, key_layer, value_layer, attention_mask
)
return hidden_states
def _allocate_memory(self, inference_max_sequence_len, batch_size):
return torch.empty(
inference_max_sequence_len,
batch_size,
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
dtype=self.params_dtype,
device=torch.cuda.current_device(),
)
def forward(
self, hidden_states, attention_mask, encoder_output=None, inference_params=None
):
# hidden_states: [sq, b, h]
# =================================================
# Pre-allocate memory for key-values for inference.
# =================================================
if inference_params:
if self.layer_number not in inference_params.key_value_memory_dict:
inf_max_seq_len = inference_params.max_sequence_len
inf_max_batch_size = inference_params.max_batch_size
inference_key_memory = self._allocate_memory(
inf_max_seq_len, inf_max_batch_size
)
inference_value_memory = self._allocate_memory(
inf_max_seq_len, inf_max_batch_size
)
inference_params.key_value_memory_dict[self.layer_number] = (
inference_key_memory,
inference_value_memory,
)
else:
(
inference_key_memory,
inference_value_memory,
) = inference_params.key_value_memory_dict[self.layer_number]
# =====================
# Query, Key, and Value
# =====================
if self.attention_type == AttnType.self_attn:
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
mixed_x_layer, _ = self.query_key_value(hidden_states)
# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + (
self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head,
)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
(
query_layer,
key_layer,
value_layer,
) = tensor_parallel.utils.split_tensor_along_last_dim(mixed_x_layer, 3)
else:
# Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
mixed_kv_layer, _ = self.key_value(encoder_output)
# [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
new_tensor_shape = mixed_kv_layer.size()[:-1] + (
self.num_attention_heads_per_partition,
2 * self.hidden_size_per_attention_head,
)
mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
# [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
(
key_layer,
value_layer,
) = tensor_parallel.utils.split_tensor_along_last_dim(mixed_kv_layer, 2)
# Attention head [sq, b, h] --> [sq, b, hp]
query_layer, _ = self.query(hidden_states)
# [sq, b, hp] --> [sq, b, np, hn]
new_tensor_shape = query_layer.size()[:-1] + (
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
)
query_layer = query_layer.view(*new_tensor_shape)
# ==================================
# Adjust key and value for inference
# ==================================
if inference_params:
batch_start = inference_params.batch_size_offset
batch_end = batch_start + key_layer.size(1)
assert batch_end <= inference_key_memory.size(1)
sequence_start = inference_params.sequence_len_offset
sequence_end = sequence_start + key_layer.size(0)
assert sequence_end <= inference_key_memory.size(0)
# Copy key and values.
inference_key_memory[
sequence_start:sequence_end, batch_start:batch_end, ...
] = key_layer
inference_value_memory[
sequence_start:sequence_end, batch_start:batch_end, ...
] = value_layer
key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
value_layer = inference_value_memory[
:sequence_end, batch_start:batch_end, ...
]
# ==================================
# core attention computation
# ==================================
if self.checkpoint_core_attention:
context_layer = self._checkpointed_attention_forward(
query_layer, key_layer, value_layer, attention_mask
)
else:
context_layer = self.core_attention(
query_layer, key_layer, value_layer, attention_mask
)
# =================
# Output. [sq, b, h]
# =================
output, bias = self.dense(context_layer)
return output, bias
def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
out = residual + out
return out
def get_bias_dropout_add(training):
def _bias_dropout_add(x, bias, residual, prob):
return bias_dropout_add(x, bias, residual, prob, training)
return _bias_dropout_add
class ParallelTransformerLayer(MegatronModule):
"""A single transformer layer.
Transformer layer takes input with size [s, b, h] and returns an
output of the same size.
"""
def __init__(
self,
init_method,
output_layer_init_method,
layer_number,
layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding,
drop_path_rate=0.0,
):
args = get_args()
super().__init__()
self.layer_number = layer_number
self.layer_type = layer_type
self.apply_residual_connection_post_layernorm = (
args.apply_residual_connection_post_layernorm
)
self.bf16 = args.bf16
self.fp32_residual_connection = args.fp32_residual_connection
# Layernorm on the input data.
self.input_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
# no_persist_layer_norm=args.no_persist_layer_norm,
sequence_parallel_enabled=args.sequence_parallel,
)
# Self attention.
self.self_attention = ParallelAttention(
init_method,
output_layer_init_method,
layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=self_attn_mask_type,
)
self.hidden_dropout = args.hidden_dropout
self.bias_dropout_fusion = args.bias_dropout_fusion
# note(mkozuki)
# self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None
assert drop_path_rate <= 0.0
self.drop_path = None
# Layernorm on the attention output
self.post_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
# no_persist_layer_norm=args.no_persist_layer_norm,
sequence_parallel_enabled=args.sequence_parallel,
)
if self.layer_type == LayerType.decoder:
self.inter_attention = ParallelAttention(
init_method,
output_layer_init_method,
layer_number,
attention_type=AttnType.cross_attn,
)
# Layernorm on the attention output.
self.post_inter_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
# no_persist_layer_norm=args.no_persist_layer_norm,
sequence_parallel_enabled=args.sequence_parallel,
)
# MLP
# note(mkozuki)
assert args.num_experts is None
# if args.num_experts is not None:
# self.mlp = SwitchMLP(init_method, output_layer_init_method)
# else:
# self.mlp = ParallelMLP(init_method, output_layer_init_method)
self.mlp = ParallelMLP(init_method, output_layer_init_method)
# Set bias+dropout+add fusion grad_enable execution handler.
TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1])
use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
self.bias_dropout_add_exec_handler = (
contextlib.nullcontext if use_nvfuser else torch.enable_grad
)
def forward(
self,
hidden_states,
attention_mask,
encoder_output=None,
enc_dec_attn_mask=None,
inference_params=None,
):
# hidden_states: [s, b, h]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output, attention_bias = self.self_attention(
layernorm_output, attention_mask, inference_params=inference_params
)
# Residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
if self.drop_path is None:
bias_dropout_add_func = get_bias_dropout_add(self.training)
with self.bias_dropout_add_exec_handler():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
residual,
self.hidden_dropout,
)
else:
out = torch.nn.functional.dropout(
attention_output + attention_bias,
p=self.hidden_dropout,
training=self.training,
)
layernorm_input = residual + self.drop_path(out)
# Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input)
if self.layer_type == LayerType.decoder:
attention_output, attention_bias = self.inter_attention(
layernorm_output, enc_dec_attn_mask, encoder_output=encoder_output
)
# residual connection
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
with self.bias_dropout_add_exec_handler():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
residual,
self.hidden_dropout,
)
# Layer norm post the decoder attention
layernorm_output = self.post_inter_attention_layernorm(layernorm_input)
# MLP.
mlp_output, mlp_bias = self.mlp(layernorm_output)
# Second residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
if self.drop_path is None:
with self.bias_dropout_add_exec_handler():
output = bias_dropout_add_func(
mlp_output,
mlp_bias.expand_as(residual),
residual,
self.hidden_dropout,
)
else:
out = torch.nn.functional.dropout(
mlp_output + mlp_bias, p=self.hidden_dropout, training=self.training
)
output = residual + self.drop_path(out)
return output
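# Data-flow summary of ParallelTransformerLayer.forward above (all activations
# are [s, b, h]; illustrative, not exhaustive):
#   input_layernorm -> self-attention -> bias + dropout + residual add
#   -> post_attention_layernorm -> (decoder layers only: cross-attention ->
#      bias + dropout + residual add -> post_inter_attention_layernorm)
#   -> MLP -> bias + dropout + residual add -> output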
class ParallelTransformer(MegatronModule):
"""Transformer class."""
def __init__(
self,
init_method,
output_layer_init_method,
layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding,
post_layer_norm=True,
pre_process=True,
post_process=True,
drop_path_rate=0.0,
):
super().__init__()
args = get_args()
self.layer_type = layer_type
self.model_type = args.model_type
self.bf16 = args.bf16
self.fp32_residual_connection = args.fp32_residual_connection
self.post_layer_norm = post_layer_norm
self.pre_process = pre_process
self.post_process = post_process
self.input_tensor = None
self.drop_path_rate = drop_path_rate
# Store activation checkpointing flags.
self.recompute_granularity = args.recompute_granularity
self.recompute_method = args.recompute_method
self.recompute_num_layers = args.recompute_num_layers
self.distribute_saved_activations = (
args.distribute_saved_activations and not args.sequence_parallel
)
self.sequence_parallel = args.sequence_parallel
# Number of layers.
self.num_layers = get_num_layers(
args, args.model_type == ModelType.encoder_and_decoder
)
self.drop_path_rates = [
rate.item()
for rate in torch.linspace(0, self.drop_path_rate, args.num_layers)
]
# Transformer layers.
def build_layer(layer_number):
return ParallelTransformerLayer(
init_method,
output_layer_init_method,
layer_number,
layer_type=layer_type,
self_attn_mask_type=self_attn_mask_type,
drop_path_rate=self.drop_path_rates[layer_number - 1],
)
if args.virtual_pipeline_model_parallel_size is not None:
assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, (
"num_layers_per_stage must be divisible by "
"virtual_pipeline_model_parallel_size"
)
assert args.model_type != ModelType.encoder_and_decoder
# Number of layers in each model chunk is the number of layers in the stage,
# divided by the number of model chunks in a stage.
self.num_layers = (
self.num_layers // args.virtual_pipeline_model_parallel_size
)
# With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
# layers to stages like (each list is a model chunk):
# Stage 0: [0] [2] [4] [6]
# Stage 1: [1] [3] [5] [7]
# With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
# layers to stages like (each list is a model chunk):
# Stage 0: [0, 1] [4, 5]
# Stage 1: [2, 3] [6, 7]
offset = parallel_state.get_virtual_pipeline_model_parallel_rank() * (
args.num_layers // args.virtual_pipeline_model_parallel_size
) + (parallel_state.get_pipeline_model_parallel_rank() * self.num_layers)
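# Worked example for the second assignment above (8 layers, 2 stages,
# virtual_pipeline_model_parallel_size=2): self.num_layers becomes
# 8 // 2 // 2 = 2 layers per model chunk, and
#   offset = virtual_rank * (8 // 2) + pipeline_rank * 2
# i.e. offsets 0 and 4 on stage 0 (chunks [0, 1] and [4, 5]) and
# offsets 2 and 6 on stage 1 (chunks [2, 3] and [6, 7]).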
else:
# Each stage gets a contiguous set of layers.
if (
args.model_type == ModelType.encoder_and_decoder
and parallel_state.get_pipeline_model_parallel_world_size() > 1
):
pipeline_rank = parallel_state.get_pipeline_model_parallel_rank()
if layer_type == LayerType.encoder:
offset = pipeline_rank * self.num_layers
else:
num_ranks_in_enc = args.pipeline_model_parallel_split_rank
offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
else:
offset = (
parallel_state.get_pipeline_model_parallel_rank() * self.num_layers
)
if self.num_layers == 0:
# When a standalone embedding stage is used (e.g.,
# args.standalone_embedding_stage == True), virtual pipeline ranks
# on pipeline rank 0 will have zero transformer layers assigned to
# them. This results in the model's input and output tensors being
# the same, which will cause failures for certain output tensor
# optimizations (e.g., pipeline output deallocation). To remedy
# this, we assign a 'no-op' layer on these ranks, which will
# disconnect the input tensor from the output tensor.
self.num_layers = 1
self.layers = torch.nn.ModuleList([NoopTransformerLayer(1)])
else:
self.layers = torch.nn.ModuleList(
[build_layer(i + 1 + offset) for i in range(self.num_layers)]
)
if self.post_process and self.post_layer_norm:
# Final layer norm before output.
self.final_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
# no_persist_layer_norm=args.no_persist_layer_norm,
sequence_parallel_enabled=args.sequence_parallel,
)
def _get_layer(self, layer_number):
return self.layers[layer_number]
def _checkpointed_forward(
self, hidden_states, attention_mask, encoder_output, enc_dec_attn_mask
):
"""Forward method with activation checkpointing."""
def custom(start, end):
def custom_forward(*inputs):
x_ = inputs[0]
attention_mask = inputs[1]
encoder_output = inputs[2]
enc_dec_attn_mask = inputs[3]
for index in range(start, end):
layer = self._get_layer(index)
x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
return x_
return custom_forward
if self.recompute_method == "uniform":
# Uniformly divide the total number of Transformer layers and checkpoint
# the input activation of each divided chunk.
# This further reduces memory usage by storing fewer checkpoints.
l = 0
while l < self.num_layers:
hidden_states = tensor_parallel.random.checkpoint(
custom(l, l + self.recompute_num_layers),
self.distribute_saved_activations,
hidden_states,
attention_mask,
encoder_output,
enc_dec_attn_mask,
)
l += self.recompute_num_layers
elif self.recompute_method == "block":
# Checkpoint the input activation of only a set number of individual
# Transformer layers and skip the rest.
# This makes full use of device memory while avoiding redundant re-computation.
for l in range(self.num_layers):
if l < self.recompute_num_layers:
hidden_states = tensor_parallel.random.checkpoint(
custom(l, l + 1),
self.distribute_saved_activations,
hidden_states,
attention_mask,
encoder_output,
enc_dec_attn_mask,
)
else:
hidden_states = custom(l, l + 1)(
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask
)
else:
raise ValueError("Invalid activation recompute method.")
return hidden_states
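# Worked example for the two recompute methods above (hypothetical config):
# with self.num_layers == 8 and self.recompute_num_layers == 2,
#   "uniform" checkpoints the inputs of layer groups [0, 1], [2, 3], [4, 5], [6, 7];
#   "block" checkpoints only layers 0 and 1 and runs layers 2..7 without
#   activation checkpointing.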
def set_input_tensor(self, input_tensor):
"""Set input tensor to be used instead of forward()'s input.
When doing pipeline parallelism the input from the previous
stage comes from communication, not from the input, so the
model's forward_step_func won't have it. This function is thus
used by internal code to bypass the input provided by the
forward_step_func"""
self.input_tensor = input_tensor
def forward(
self,
hidden_states,
attention_mask,
encoder_output=None,
enc_dec_attn_mask=None,
inference_params=None,
):
# hidden_states: [s, b, h]
# Checks.
if inference_params:
assert (
self.recompute_granularity is None
), "inference does not work with activation checkpointing"
if not self.pre_process:
# See set_input_tensor()
hidden_states = self.input_tensor
# Viewless tensor.
# - We only need to create a viewless tensor in the case of micro batch
# size (mbs) == 1, since in this case, 'hidden_states.transpose()'
# above creates a view tensor, and '.contiguous()' is a pass-through.
# For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
# the need to make it viewless.
#
# However, we don't explicitly check mbs == 1 here because
# make_viewless_tensor() has negligible overhead when its input
# is already viewless.
#
# - For the 'else' case above, calling make_viewless_tensor() here is
# likely redundant, since p2p_communication.py (likely originator)
# already creates viewless tensors. That said, make_viewless_tensor()
# is called here to be future-proof and corner-case-proof.
# hidden_states = mpu.make_viewless_tensor(hidden_states, requires_grad=True, keep_graph=True)
if self.sequence_parallel:
rng_context = tensor_parallel.get_cuda_rng_tracker().fork()
else:
rng_context = contextlib.nullcontext()
with rng_context:
# Forward pass.
if self.recompute_granularity == "full":
hidden_states = self._checkpointed_forward(
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask
)
else:
for index in range(self.num_layers):
layer = self._get_layer(index)
hidden_states = layer(
hidden_states,
attention_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
inference_params=inference_params,
)
# Final layer norm.
if self.post_process and self.post_layer_norm:
hidden_states = self.final_layernorm(hidden_states)
return hidden_states
def get_num_layers(args, is_encoder_and_decoder_model):
"""Compute the number of transformer layers resident on the current rank."""
if parallel_state.get_pipeline_model_parallel_world_size() > 1:
if is_encoder_and_decoder_model:
assert args.pipeline_model_parallel_split_rank is not None
# When a standalone embedding stage is used, a rank is taken from
# the encoder's ranks, to be used for the encoder's embedding
# layer. This way, the rank referenced by the 'split rank' remains
# the same whether or not a standalone embedding stage is used.
num_ranks_in_encoder = (
args.pipeline_model_parallel_split_rank - 1
if args.standalone_embedding_stage
else args.pipeline_model_parallel_split_rank
)
num_ranks_in_decoder = (
args.transformer_pipeline_model_parallel_size - num_ranks_in_encoder
)
assert args.num_layers % num_ranks_in_encoder == 0, (
"num_layers (%d) must be divisible by number of ranks given to encoder (%d)"
% (
args.num_layers,
num_ranks_in_encoder,
)
)
assert args.num_layers % num_ranks_in_decoder == 0, (
"num_layers (%d) must be divisible by number of ranks given to decoder (%d)"
% (
args.num_layers,
num_ranks_in_decoder,
)
)
if parallel_state.is_pipeline_stage_before_split():
num_layers = (
0
if args.standalone_embedding_stage
and parallel_state.get_pipeline_model_parallel_rank() == 0
else args.num_layers // num_ranks_in_encoder
)
else:
num_layers = args.num_layers // num_ranks_in_decoder
else:
assert (
args.num_layers % args.transformer_pipeline_model_parallel_size == 0
), "num_layers must be divisible by transformer_pipeline_model_parallel_size"
# When a standalone embedding stage is used, all transformer layers
# are divided among pipeline rank >= 1, while on pipeline rank 0,
# ranks either contain the input embedding layer (virtual pp rank 0),
# or no layers at all (virtual pp rank >= 1).
num_layers = (
0
if args.standalone_embedding_stage
and parallel_state.get_pipeline_model_parallel_rank() == 0
else args.num_layers // args.transformer_pipeline_model_parallel_size
)
else:
num_layers = args.num_layers
return num_layers
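# Worked example of the split above (hypothetical configuration):
#   num_layers=12, transformer_pipeline_model_parallel_size=4,
#   pipeline_model_parallel_split_rank=2, standalone_embedding_stage=False,
#   model_type=encoder_and_decoder
# => num_ranks_in_encoder = 2 and num_ranks_in_decoder = 4 - 2 = 2, so ranks
#    before the split each hold 12 // 2 = 6 encoder layers and ranks after the
#    split each hold 12 // 2 = 6 decoder layers.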
class NoopTransformerLayer(MegatronModule):
"""A single 'no-op' transformer layer.
This layer exists solely for the case where a standalone embedding layer
is used (i.e., args.standalone_embedding_stage == True). In this case,
zero transformer layers are assigned when pipeline rank == 0. Additionally,
when virtual pipeline rank >= 1, zero total model parameters are created
(virtual rank 0 contains the input embedding). This results in the model's
input and output tensors being the same, which causes an error when
performing certain memory optimizations on the output tensor (e.g.,
deallocating it). Thus, this layer disconnects the input from the output
via a clone. Since ranks containing a no-op layer are generally under-
utilized (both compute and memory), there's no worry of any performance
degradation.
"""
def __init__(self, layer_number):
super().__init__()
self.layer_number = layer_number
def forward(
self,
hidden_states,
attention_mask,
encoder_output=None,
enc_dec_attn_mask=None,
inference_params=None,
):
return hidden_states.clone()
def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None):
"""LM logits using word embedding weights."""
args = get_args()
# Parallel logits.
if args.async_tensor_model_parallel_allreduce or args.sequence_parallel:
input_parallel = input_
model_parallel = parallel_state.get_tensor_model_parallel_world_size() > 1
async_grad_allreduce = (
args.async_tensor_model_parallel_allreduce
and model_parallel
and not args.sequence_parallel
)
else:
input_parallel = tensor_parallel.copy_to_tensor_model_parallel_region(input_)
async_grad_allreduce = False
# Matrix multiply.
# logits_parallel = tensor_parallel.layers.LinearWithGradAccumulationAndAsyncCommunication.apply(
# input_parallel, word_embeddings_weight, bias, args.gradient_accumulation_fusion, async_grad_allreduce, args.sequence_parallel)
logits_parallel = (
tensor_parallel.layers.linear_with_grad_accumulation_and_async_allreduce(
input_parallel,
word_embeddings_weight,
bias,
args.gradient_accumulation_fusion,
async_grad_allreduce,
args.sequence_parallel,
)
)
# Gather if needed.
if parallel_output:
return logits_parallel
return tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel)
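# Branch summary of parallel_lm_logits above (descriptive sketch, not a spec):
#   - If async tensor-model-parallel all-reduce or sequence parallelism is
#     enabled, the input is consumed as-is and the required communication is
#     handled inside linear_with_grad_accumulation_and_async_allreduce
#     (async_grad_allreduce is forced to False under sequence parallelism).
#   - Otherwise the input is first copied to the tensor-model-parallel region
#     and the gradient all-reduce is synchronous.
#   - With parallel_output=False the parallel logits are gathered across the
#     tensor-model-parallel group before being returned.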
def get_language_model(
num_tokentypes,
add_pooler,
encoder_attn_mask_type,
init_method=None,
scaled_init_method=None,
add_encoder=True,
add_decoder=False,
decoder_attn_mask_type=AttnMaskType.causal,
pre_process=True,
post_process=True,
):
"""Build language model and return along with the key to save."""
args = get_args()
if init_method is None:
init_method = init_method_normal(args.init_method_std)
if scaled_init_method is None:
scaled_init_method = scaled_init_method_normal(
args.init_method_std, args.num_layers
)
# Language model.
language_model = TransformerLanguageModel(
init_method,
scaled_init_method,
encoder_attn_mask_type,
num_tokentypes=num_tokentypes,
add_encoder=add_encoder,
add_decoder=add_decoder,
decoder_attn_mask_type=decoder_attn_mask_type,
add_pooler=add_pooler,
pre_process=pre_process,
post_process=post_process,
)
# key used for checkpoints.
language_model_key = "language_model"
return language_model, language_model_key
class Pooler(MegatronModule):
"""Pooler layer.
Pool hidden states of a specific token (for example start of the
sequence) and add a linear transformation followed by a tanh.
Arguments:
hidden_size: hidden size
init_method: weight initialization method for the linear layer.
bias is set to zero.
"""
def __init__(self, hidden_size, init_method):
super().__init__()
args = get_args()
self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
self.sequence_parallel = args.sequence_parallel
def forward(self, hidden_states, sequence_index=0):
# hidden_states: [s, b, h]
# sequence_index: index of the token to pool.
# Gather data along the sequence dimension.
# The same pooler is run on all tensor parallel ranks.
if self.sequence_parallel:
hidden_states = tensor_parallel.mappings.gather_from_sequence_parallel_region(hidden_states)
pooled = hidden_states[sequence_index, :, :]
pooled = self.dense(pooled)
pooled = torch.tanh(pooled)
return pooled
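# Shape sketch for Pooler.forward above (hypothetical sizes): hidden_states of
# shape [s=128, b=4, h=1024] is first all-gathered along the sequence dimension
# when sequence parallelism is enabled; then hidden_states[sequence_index]
# (shape [4, 1024]) goes through the dense layer and tanh, so `pooled` is
# [4, 1024].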
class Embedding(MegatronModule):
"""Language model embeddings.
Arguments:
hidden_size: hidden size
vocab_size: vocabulary size
max_sequence_length: maximum size of sequence. This
is used for positional embedding
embedding_dropout_prob: dropout probability for embeddings
init_method: weight initialization method
num_tokentypes: size of the token-type embeddings. A value of 0
disables this embedding
"""
def __init__(
self,
hidden_size,
vocab_size,
max_sequence_length,
embedding_dropout_prob,
init_method,
num_tokentypes=0,
):
super().__init__()
self.hidden_size = hidden_size
self.init_method = init_method
self.num_tokentypes = num_tokentypes
args = get_args()
# Word embeddings (parallel).
self.word_embeddings = VocabParallelEmbedding(
vocab_size, self.hidden_size, init_method=self.init_method
)
self._word_embeddings_key = "word_embeddings"
# Position embedding (serial).
self.position_embeddings = torch.nn.Embedding(
max_sequence_length, self.hidden_size
)
self._position_embeddings_key = "position_embeddings"
# Initialize the position embeddings.
self.init_method(self.position_embeddings.weight)
# Token type embedding.
# Add this as an optional field that can be added through a
# method call so we can load a pretrained model without
# token types and add them as needed.
self._tokentype_embeddings_key = "tokentype_embeddings"
if self.num_tokentypes > 0:
self.tokentype_embeddings = torch.nn.Embedding(
self.num_tokentypes, self.hidden_size
)
# Initialize the token-type embeddings.
self.init_method(self.tokentype_embeddings.weight)
else:
self.tokentype_embeddings = None
self.fp32_residual_connection = args.fp32_residual_connection
self.sequence_parallel = args.sequence_parallel
# Embeddings dropout
self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
def zero_parameters(self):
"""Zero out all parameters in embedding."""
self.word_embeddings.weight.data.fill_(0)
self.word_embeddings.weight.shared = True
self.position_embeddings.weight.data.fill_(0)
self.position_embeddings.weight.shared = True
if self.num_tokentypes > 0:
self.tokentype_embeddings.weight.fill_(0)
self.tokentype_embeddings.weight.shared = True
def add_tokentype_embeddings(self, num_tokentypes):
"""Add token-type embedding. This function is provided so we can add
token-type embeddings in case the pretrained model does not have them.
This allows us to load the model normally and then add this embedding.
"""
if self.tokentype_embeddings is not None:
raise Exception("tokentype embeddings is already initialized")
if torch.distributed.get_rank() == 0:
print(
"adding embedding for {} tokentypes".format(num_tokentypes), flush=True
)
self.num_tokentypes = num_tokentypes
self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, self.hidden_size)
# Initialize the token-type embeddings.
self.init_method(self.tokentype_embeddings.weight)
def forward(self, input_ids, position_ids, tokentype_ids=None):
# Embeddings.
words_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
embeddings = words_embeddings + position_embeddings
if tokentype_ids is not None:
assert self.tokentype_embeddings is not None
embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
else:
assert self.tokentype_embeddings is None
# Data format change to avoid explicit transposes: [b s h] --> [s b h].
embeddings = embeddings.transpose(0, 1).contiguous()
# If the input flag for fp32 residual connection is set, convert for float.
if self.fp32_residual_connection:
embeddings = embeddings.float()
# Dropout.
if self.sequence_parallel:
embeddings = scatter_to_sequence_parallel_region(embeddings)
with tensor_parallel.get_cuda_rng_tracker().fork():
embeddings = self.embedding_dropout(embeddings)
else:
embeddings = self.embedding_dropout(embeddings)
return embeddings
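# Shape sketch for Embedding.forward above (hypothetical sizes): input_ids and
# position_ids of shape [b=4, s=128] produce embeddings of shape [4, 128, h],
# transposed to [128, 4, h]. With sequence parallelism and a tensor-model-parallel
# world size of 2, scatter_to_sequence_parallel_region splits the sequence
# dimension, so each rank applies dropout to a [64, 4, h] shard.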
class TransformerLanguageModel(MegatronModule):
"""Transformer language model.
Arguments:
transformer_hparams: transformer hyperparameters
vocab_size: vocabulary size
max_sequence_length: maximum size of sequence. This
is used for positional embedding
embedding_dropout_prob: dropout probability for embeddings
num_tokentypes: size of the token-type embeddings. A value of 0
disables this embedding
"""
def __init__(
self,
init_method,
output_layer_init_method,
encoder_attn_mask_type,
num_tokentypes=0,
add_encoder=True,
add_decoder=False,
decoder_attn_mask_type=AttnMaskType.causal,
add_pooler=False,
pre_process=True,
post_process=True,
):
super().__init__()
args = get_args()
self.pre_process = pre_process
self.post_process = post_process
self.hidden_size = args.hidden_size
self.num_tokentypes = num_tokentypes
self.init_method = init_method
self.add_encoder = add_encoder
self.encoder_attn_mask_type = encoder_attn_mask_type
self.add_decoder = add_decoder
self.decoder_attn_mask_type = decoder_attn_mask_type
self.add_pooler = add_pooler
self.encoder_hidden_state = None
# Embeddings.
if self.pre_process:
self.embedding = Embedding(
self.hidden_size,
args.padded_vocab_size,
args.max_position_embeddings,
args.hidden_dropout,
self.init_method,
self.num_tokentypes,
)
self._embedding_key = "embedding"
# Transformer.
# Encoder (usually set to True, False if part of an encoder-decoder
# architecture and in encoder-only stage).
if self.add_encoder:
self.encoder = ParallelTransformer(
self.init_method,
output_layer_init_method,
self_attn_mask_type=self.encoder_attn_mask_type,
pre_process=self.pre_process,
post_process=self.post_process,
)
self._encoder_key = "encoder"
else:
self.encoder = None
# Decoder (usually set to False, True if part of an encoder-decoder
# architecture and in decoder-only stage).
if self.add_decoder:
self.decoder = ParallelTransformer(
self.init_method,
output_layer_init_method,
layer_type=LayerType.decoder,
self_attn_mask_type=self.decoder_attn_mask_type,
pre_process=self.pre_process,
post_process=self.post_process,
)
self._decoder_key = "decoder"
else:
self.decoder = None
if self.post_process:
# Pooler.
if self.add_pooler:
self.pooler = Pooler(self.hidden_size, self.init_method)
self._pooler_key = "pooler"
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
# This is usually handled in schedules.py but some inference code still
# gives us non-lists or None
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
if self.add_encoder and self.add_decoder:
assert (
len(input_tensor) == 1
), "input_tensor should only be length 1 for stage with both encoder and decoder"
self.encoder.set_input_tensor(input_tensor[0])
elif self.add_encoder:
assert (
len(input_tensor) == 1
), "input_tensor should only be length 1 for stage with only encoder"
self.encoder.set_input_tensor(input_tensor[0])
elif self.add_decoder:
if len(input_tensor) == 2:
self.decoder.set_input_tensor(input_tensor[0])
self.encoder_hidden_state = input_tensor[1]
elif len(input_tensor) == 1:
self.decoder.set_input_tensor(None)
self.encoder_hidden_state = input_tensor[0]
else:
raise Exception("input_tensor must have either length 1 or 2")
else:
raise Exception("Stage must have at least either encoder or decoder")
def forward(
self,
enc_input_ids,
enc_position_ids,
enc_attn_mask,
dec_input_ids=None,
dec_position_ids=None,
dec_attn_mask=None,
enc_dec_attn_mask=None,
tokentype_ids=None,
inference_params=None,
pooling_sequence_index=0,
enc_hidden_states=None,
output_enc_hidden=False,
):
args = get_args()
# Encoder embedding.
if self.pre_process:
encoder_input = self.embedding(
enc_input_ids, enc_position_ids, tokentype_ids=tokentype_ids
)
else:
encoder_input = None
# Run encoder.
if enc_hidden_states is None:
if self.encoder is not None:
encoder_output = self.encoder(
encoder_input, enc_attn_mask, inference_params=inference_params
)
else:
encoder_output = self.encoder_hidden_state
else:
encoder_output = enc_hidden_states.to(encoder_input.dtype)
if self.post_process:
if self.add_pooler:
pooled_output = self.pooler(encoder_output, pooling_sequence_index)
# output_enc_hidden is set when only the encoder's output is needed,
# e.g., to compute similarity between two sequences via average pooling.
if not self.add_decoder or output_enc_hidden:
if self.add_pooler and self.post_process:
return encoder_output, pooled_output
else:
return encoder_output
# Decoder embedding.
if self.pre_process:
decoder_input = self.embedding(dec_input_ids, dec_position_ids)
else:
decoder_input = None
# Run decoder.
decoder_output = self.decoder(
decoder_input,
dec_attn_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
inference_params=inference_params,
)
if self.add_pooler and self.post_process:
return decoder_output, encoder_output, pooled_output
else:
return decoder_output, encoder_output
def post_language_model_processing(
lm_output, labels, logit_weights, parallel_output, fp16_lm_cross_entropy
):
# Output.
output = parallel_lm_logits(lm_output, logit_weights, parallel_output)
if labels is None:
return output
else:
if fp16_lm_cross_entropy:
assert output.dtype == torch.half
loss = tensor_parallel.vocab_parallel_cross_entropy(output, labels)
else:
loss = tensor_parallel.vocab_parallel_cross_entropy(output.float(), labels)
return loss
def module_size(m: torch.nn.Module, only_trainable: bool = False):
"""
Return the total number of parameters used by `m`, counting shared
parameters only once. If `only_trainable` is True, only parameters
with `requires_grad = True` are included.
"""
parameters = list(m.parameters())
if only_trainable:
parameters = [p for p in parameters if p.requires_grad]
unique = {p.data_ptr(): p for p in parameters}.values()
return sum(p.numel() for p in unique)
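# Usage sketch for module_size (hypothetical module):
#   layer = torch.nn.Linear(16, 16)
#   module_size(layer)                       # 16 * 16 + 16 = 272 parameters
#   module_size(layer, only_trainable=True)  # same here; frozen params would be excluded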
@@ -8,11 +8,13 @@ else:
HAS_TORCH_UCC = True
print("Use UCC as backend of Pipeline Parallel ProcessGroups")
from apex.transformer.enums import ModelType
from apex.transformer import tensor_parallel
from apex.transformer import parallel_state
from apex.transformer.log_util import set_logging_level
from apex.transformer.tensor_parallel import vocab_parallel_cross_entropy
from apex.transformer.pipeline_parallel.utils import setup_microbatch_calculator
from apex.transformer.pipeline_parallel.utils import unwrap_model
from apex.transformer.pipeline_parallel.utils import (
average_losses_across_data_parallel_group,
)
@@ -148,8 +150,24 @@ def train(
batch = generate_fancy_data_labels(sequence_len, batch_size)
optim.zero_grad()
forward_backward_func(
fwd_step_func, batch, model, forward_only=False, tensor_shape=tensor_shape, async_comm=async_comm,
fwd_step_func,
batch,
model,
forward_only=False,
tensor_shape=tensor_shape,
async_comm=async_comm,
sequence_parallel_enabled=global_vars.get_args().sequence_parallel,
)
# All-reduce layernorm parameter gradients across the tensor model parallel group
# when sequence parallelism is used
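# Rationale (sketch): layernorm weights are replicated across the tensor
# parallel group, but with sequence parallelism each rank only processes its
# shard of the sequence, so these gradients must be summed manually before
# the optimizer step.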
if parallel_state.get_tensor_model_parallel_world_size() > 1 and global_vars.get_args().sequence_parallel:
for model_module in model:
unwrapped_model = unwrap_model(model_module)
for param in unwrapped_model.parameters():
if getattr(param, 'sequence_parallel_enabled', False):
grad = param.grad
torch.distributed.all_reduce(grad, group=parallel_state.get_tensor_model_parallel_group())
optim.step()
@@ -169,13 +187,15 @@ if __name__ == "__main__":
init = True
try:
for virtual_pipeline_model_parallel_size in (2, None):
async_comm = virtual_pipeline_model_parallel_size is None
args = global_vars.get_args()
async_comm = not args.sequence_parallel and virtual_pipeline_model_parallel_size is None
data_idx = 0
ONCE = False
if init:
init = False
args = global_vars.get_args()
args.padded_vocab_size = 128 # needed in standalone gpt
args.model_type = ModelType.encoder_or_decoder
batch_size = args.global_batch_size
micro_batch_size = args.micro_batch_size
setup_microbatch_calculator(
@@ -201,7 +221,7 @@ if __name__ == "__main__":
tensor_parallel.random.model_parallel_cuda_manual_seed(0)
model = build_model(
bert_model_provider,
wrap_with_ddp=True,
wrap_with_ddp=parallel_state.get_data_parallel_world_size() > 1,
virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
cpu_offload=args.cpu_offload,
)
@@ -12,8 +12,10 @@ else:
print("Use UCC as backend of Pipeline Parallel ProcessGroups")
from apex.transformer import parallel_state
from apex.transformer.enums import ModelType
from apex.transformer.tensor_parallel import model_parallel_cuda_manual_seed
from apex.transformer.pipeline_parallel.utils import setup_microbatch_calculator
from apex.transformer.pipeline_parallel.utils import unwrap_model
from apex.transformer.pipeline_parallel.utils import (
average_losses_across_data_parallel_group,
)
@@ -132,10 +134,25 @@ def train(model, optim, pipeline_model_parallel_size, async_comm):
print("finished making batch...")
optim.zero_grad()
fwd_bwd_func(
fwd_step_func, batch, model, forward_only=False, tensor_shape=tensor_shape, async_comm=async_comm
fwd_step_func,
batch,
model,
forward_only=False,
tensor_shape=tensor_shape,
async_comm=async_comm,
sequence_parallel_enabled=args.sequence_parallel,
)
if torch.distributed.get_rank() == 0:
print("finished forward step")
# All-reduce layernorm parameter gradients across the tensor model parallel group
# when sequence parallelism is used
if parallel_state.get_tensor_model_parallel_world_size() > 1 and global_vars.get_args().sequence_parallel:
for model_module in model:
unwrapped_model = unwrap_model(model_module)
for param in unwrapped_model.parameters():
if getattr(param, 'sequence_parallel_enabled', False):
grad = param.grad
torch.distributed.all_reduce(grad, group=parallel_state.get_tensor_model_parallel_group())
optim.step()
if torch.distributed.get_rank() == 0:
print("finished iter", i)
@@ -145,16 +162,17 @@ def train(model, optim, pipeline_model_parallel_size, async_comm):
if __name__ == "__main__":
init = True
for async_comm in (False, True):
global_vars.set_global_variables()
for async_comm in (False,) if global_vars.get_args().sequence_parallel else (False, True):
global fancy_data
global effective_length
if init:
init = False
global_vars.set_global_variables()
fancy_data = download_fancy_data()
args = global_vars.get_args()
args.model_type = ModelType.encoder_or_decoder
effective_length = fancy_data.size(0) // args.seq_length
effective_length = fancy_data.size(0) - args.seq_length
@@ -189,7 +207,7 @@ if __name__ == "__main__":
model_parallel_cuda_manual_seed(0)
model = build_model(
gpt_model_provider,
wrap_with_ddp=True,
wrap_with_ddp=parallel_state.get_data_parallel_world_size() > 1,
virtual_pipeline_model_parallel_size=None,
cpu_offload=args.cpu_offload,
)
import logging
import unittest
import typing
import torch
import torch.nn as nn
from torch.testing._internal import common_utils
logging.getLogger("torch").setLevel(logging.WARNING)
from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import layers
from apex.transformer.testing.commons import set_random_seed
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase
logging.getLogger("torch").setLevel(logging.WARNING)
logging.getLogger("apex").setLevel(logging.WARNING)
# N.B. (mkozuki): Disable TF32 matrix multiply.
# N.B.(mkozuki): Disable TF32 matrix multiply.
# Matrices used in this test are so small that TF32 matmul
# can be less precise, causing `self.assertEqual` to fail.
torch.backends.cuda.matmul.allow_tf32 = False
@@ -23,13 +25,114 @@ torch.backends.cuda.matmul.allow_tf32 = False
class TensorParallelLayerTestBase:
BATCH_SIZE: int = 17
SEQUENCE_LENGTH: int = 23
VOCAB_SIZE: int = 48
HIDDEN_SIZE: int = 16
INPUT_SIZE_COEFF: int = 13
OUTPUT_SIZE_COEFF: int = 17
SEED: int = 123
BATCH_SIZE: int = 8
SEQUENCE_LENGTH: int = 128
VOCAB_SIZE: int = 1024
HIDDEN_SIZE: int = 256
INPUT_SIZE_COEFF: int = 256
OUTPUT_SIZE_COEFF: int = 256
SEED: int = 123456
@property
def tensor_shape(self) -> typing.Sequence[int]:
return [self.SEQUENCE_LENGTH, self.BATCH_SIZE, self.HIDDEN_SIZE]
@torch.no_grad()
@unittest.skipIf(torch.cuda.device_count() < 2, "Requires >=2 GPUs")
def test_all_gather_parity(self) -> None:
if self.DISTRIBUTED_BACKEND == "ucc":
self.skipTest("torch_ucc does NOT support `torch.distributed._all_gather_base` as of 2022/06/15")
from torch.distributed.distributed_c10d import all_gather, _all_gather_base # NOQA
for tensor_model_parallel_world_size in range(1, self.world_size + 1):
if self.world_size % tensor_model_parallel_world_size:
continue
with self.subTest(tensor_model_parallel_world_size=tensor_model_parallel_world_size):
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_parallel_world_size,
)
tensor_model_parallel_rank = parallel_state.get_tensor_model_parallel_rank()
cur_tensor_model_device = torch.device(f"cuda:{tensor_model_parallel_rank}")
with torch.no_grad():
tensor = tensor_model_parallel_rank * torch.ones(
self.tensor_shape, dtype=torch.float32, device=cur_tensor_model_device)
numel = tensor.numel()
numel_gathered = tensor_model_parallel_world_size * numel
gathered = torch.empty(
torch.Size((numel_gathered,)),
device=cur_tensor_model_device,
dtype=torch.float32,
requires_grad=False,
)
chunks = [
gathered[i * numel : (i + 1) * numel]
for i in range(tensor_model_parallel_world_size)
]
all_gather(chunks, tensor, group=parallel_state.get_tensor_model_parallel_group())
gathered_for_base = torch.empty(
torch.Size((numel_gathered,)),
device=cur_tensor_model_device,
dtype=torch.float32,
requires_grad=False,
)
_all_gather_base(
gathered_for_base,
tensor,
group=parallel_state.get_tensor_model_parallel_group(),
)
torch.testing.assert_close(gathered, gathered_for_base)
parallel_state.destroy_model_parallel()
@torch.no_grad()
@unittest.skipIf(torch.cuda.device_count() < 2, "Requires >=2 GPUs")
def test_reduce_scatter_parity(self) -> None:
if self.DISTRIBUTED_BACKEND == "ucc":
self.skipTest("torch_ucc does NOT support `torch.distributed._reduce_scatter_base` as of 2022/06/15")
from torch.distributed.distributed_c10d import reduce_scatter, _reduce_scatter_base # NOQA
for tensor_model_parallel_world_size in range(2, self.world_size + 1):
if self.world_size % tensor_model_parallel_world_size:
continue
with self.subTest(tensor_model_parallel_world_size=tensor_model_parallel_world_size):
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_parallel_world_size,
)
tensor_model_parallel_rank = parallel_state.get_tensor_model_parallel_rank()
cur_tensor_model_device = torch.device(f"cuda:{tensor_model_parallel_rank}")
with torch.no_grad():
input = torch.cat([
i * torch.ones(self.tensor_shape, dtype=torch.float32, device=cur_tensor_model_device)
for i in range(tensor_model_parallel_world_size)
])
input_list = [t.clone() for t in input.chunk(tensor_model_parallel_world_size)]
output = torch.empty(
self.tensor_shape,
device=cur_tensor_model_device,
dtype=torch.float32,
requires_grad=False,
)
reduce_scatter(
output, input_list,
group=parallel_state.get_tensor_model_parallel_group(),
)
output_for_base = torch.empty(
self.tensor_shape,
device=cur_tensor_model_device,
dtype=torch.float32,
requires_grad=False,
)
_reduce_scatter_base(
output_for_base,
input,
group=parallel_state.get_tensor_model_parallel_group(),
)
torch.testing.assert_close(output, output_for_base)
torch.testing.assert_close(input, torch.cat(input_list))
parallel_state.destroy_model_parallel()
def test_parallel_embedding(self) -> None:
for tensor_model_parallel_world_size in range(1, self.world_size + 1):
@@ -69,7 +172,7 @@ class TensorParallelLayerTestBase:
loss_torch = torch.mul(output_torch, loss_weight).sum()
loss_torch.backward()
# N.B. (mkozuki): With affine weight initialization on GPU,
# N.B.(mkozuki): With affine weight initialization on GPU,
# it's super difficult to keep the consistency with nn.Embedding.
# Thus, turning on `use_cpu_initialization`.
set_random_seed(self.SEED)
@@ -171,139 +274,274 @@ class TensorParallelLayerTestBase:
self._affine_weight_init_test_impl(init_device="gpu", is_column_parallel=False)
def test_row_parallel_linear(self) -> None:
for tensor_model_parallel_world_size in range(1, self.world_size + 1):
self._row_parallel_linear_test_impl(False, False, False)
def test_row_parallel_linear_gradient_accumulation_fusion(self) -> None:
self._row_parallel_linear_test_impl(True, False, False)
def test_row_parallel_linear_gradient_accumulation_fusion_in_fp16(self) -> None:
self._row_parallel_linear_test_impl(True, True, False)
@unittest.skipIf(torch.cuda.device_count() < 2, "Sequence Parallel requires >=2 GPUs")
def test_row_parallel_linear_sequence_parallel(self) -> None:
self._row_parallel_linear_test_impl(False, False, True)
# TODO(mkozuki): Merge this with `_column_parallel_linear_test_impl`
# Note that `input_is_parallel` is unique to `RowParallelLinear` which could make the merge complicated.
def _row_parallel_linear_test_impl(
self,
gradient_accumulation_fusion: bool,
accumulation_in_fp16: bool,
sequence_parallel_enabled: bool,
) -> None:
tensor_shape = (
self.SEQUENCE_LENGTH,
self.BATCH_SIZE,
self.HIDDEN_SIZE,
)
for tensor_model_parallel_world_size in range(
1 + int(sequence_parallel_enabled), self.world_size + 1
):
if self.world_size % tensor_model_parallel_world_size:
continue
with self.subTest(
tensor_model_parallel_world_size=tensor_model_parallel_world_size
tensor_model_parallel_world_size=tensor_model_parallel_world_size,
):
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_parallel_world_size
tensor_model_parallel_size_=tensor_model_parallel_world_size,
)
input_size: int = self.INPUT_SIZE_COEFF * tensor_model_parallel_world_size
output_size: int = self.OUTPUT_SIZE_COEFF * tensor_model_parallel_world_size
set_random_seed(self.SEED)
linear_layer = layers.RowParallelLinear(
input_size,
output_size,
linear = layers.RowParallelLinear(
self.HIDDEN_SIZE,
self.HIDDEN_SIZE,
keep_master_weight_for_test=True,
params_dtype=torch.float32,
use_cpu_initialization=True,
gradient_accumulation_fusion=gradient_accumulation_fusion,
accumulation_in_fp16=accumulation_in_fp16,
sequence_parallel_enabled=sequence_parallel_enabled,
# n.b.(mkozuki): RowParallelLinear is constructed with `input_is_parallel=True`
# by default, e.g. https://github.com/NVIDIA/NeMo/blob/782b4e1652aaa43c8be390d9\
# db0dc89544afa080/nemo/collections/nlp/modules/common/megatron/transformer.py#L204
input_is_parallel=True,
).cuda()
loss_weight = torch.randn(
(self.BATCH_SIZE, output_size)
).cuda()
if accumulation_in_fp16:
linear = linear.half()
# Simulate the situation where fusion of weight grad calculation and gradient accumulation is enabled.
if gradient_accumulation_fusion:
with torch.no_grad():
linear.weight.main_grad = torch.zeros_like(linear.weight)
# Forward and backward
input_tensor = torch.randn(
self.BATCH_SIZE, input_size, requires_grad=True
).cuda()
input_tensor.retain_grad()
output, _ = linear_layer(input_tensor)
with torch.no_grad():
orig_input_tensor = torch.randn(tensor_shape, requires_grad=True, device="cuda")
orig_loss_weight = torch.randn(tensor_shape, device="cuda")
input_tensor = orig_input_tensor.chunk(
chunks=tensor_model_parallel_world_size,
dim=2,
)[parallel_state.get_tensor_model_parallel_rank()].contiguous()
if sequence_parallel_enabled:
loss_weight = orig_loss_weight.chunk(
chunks=tensor_model_parallel_world_size,
dim=0,
)[parallel_state.get_tensor_model_parallel_rank()]
else:
loss_weight = orig_loss_weight
if accumulation_in_fp16:
orig_input_tensor = orig_input_tensor.half()
input_tensor = input_tensor.half()
loss_weight = loss_weight.half()
input_tensor.requires_grad_()
output, _ = linear(input_tensor)
loss = torch.mul(output, loss_weight).sum()
loss.backward()
self.assertIsNotNone(input_tensor.grad)
ref_linear = nn.Linear(
in_features=self.HIDDEN_SIZE,
out_features=self.HIDDEN_SIZE,
bias=False,
device="cuda",
)
with torch.no_grad():
dldy = loss_weight.clone()
x = input_tensor.clone()
a = linear_layer.master_weight.cuda()
dlda = torch.matmul(dldy.t(), x)
dldb = torch.matmul(
torch.ones(self.BATCH_SIZE, 1).cuda().t(), dldy
).view(-1)
dldx = torch.matmul(dldy, a)
with torch.no_grad():
curr_dlda = torch.split(
dlda, self.INPUT_SIZE_COEFF, dim=1
)[parallel_state.get_tensor_model_parallel_rank()].clone()
self.assertEqual(linear_layer.weight.grad, curr_dlda)
self.assertEqual(input_tensor.grad, dldx)
self.assertEqual(linear_layer.bias.grad, dldb)
dldy = orig_loss_weight.clone()
x = orig_input_tensor.clone()
ref_linear.weight.copy_(linear.master_weight)
if accumulation_in_fp16:
ref_linear = ref_linear.half()
x.requires_grad_()
expected_output = ref_linear(x)
expected_loss = torch.mul(expected_output, dldy).sum()
expected_loss.backward()
if not accumulation_in_fp16:
if sequence_parallel_enabled:
torch.testing.assert_close(
actual=output,
expected=expected_output.chunk(
chunks=tensor_model_parallel_world_size,
dim=0,
)[parallel_state.get_tensor_model_parallel_rank()],
)
else:
torch.testing.assert_close(
actual=output,
expected=expected_output,
)
grad_attr_name = "main_grad" if gradient_accumulation_fusion else "grad"
# NOTE(mkozuki): Numerical errors seem to be enlarged by tensor model parallelism.
if tensor_model_parallel_world_size == 1:
torch.testing.assert_close(
actual=getattr(linear.weight, grad_attr_name),
expected=ref_linear.weight.grad.chunk(
chunks=tensor_model_parallel_world_size,
dim=0,
)[parallel_state.get_tensor_model_parallel_rank()],
)
parallel_state.destroy_model_parallel()
def test_column_parallel_linear(self):
self._column_parallel_linear_test_impl(False, False)
self._column_parallel_linear_test_impl(False, False, False, False)
def test_column_parallel_linear_no_async(self):
self._column_parallel_linear_test_impl(True, False)
def test_column_parallel_linear_async(self):
self._column_parallel_linear_test_impl(True, False, False, False)
def test_column_parallel_linear_gradient_accumulation_fusion(self):
self._column_parallel_linear_test_impl(False, True)
self._column_parallel_linear_test_impl(False, True, False, False)
def test_column_parallel_linear_gradient_accumulation_fusion_in_fp16(self):
self._column_parallel_linear_test_impl(False, True, True, False)
def test_column_parallel_linear_sequence_parallel(self):
if self.DISTRIBUTED_BACKEND == "ucc":
self.skipTest("Backward's reduce_scatter fails. as of 2022/06/15")
self._column_parallel_linear_test_impl(False, False, False, True)
@unittest.skipIf(torch.cuda.device_count() < 2, "Sequence Parallel requires >= 2 GPUs")
def test_column_parallel_linear_exception(self):
with self.assertRaisesRegex(
RuntimeError,
"`async_tensor_model_parallel_allreduce` and `sequence_parallel_enabled` cannot be enabled at the same time.",
):
self._column_parallel_linear_test_impl(True, False, False, True)
def _column_parallel_linear_test_impl(
self,
no_async_tensor_model_parallel_allreduce: bool,
async_tensor_model_parallel_allreduce: bool,
gradient_accumulation_fusion: bool,
accumulation_in_fp16: bool,
sequence_parallel_enabled: bool,
):
for tensor_model_parallel_world_size in range(1, self.world_size + 1):
with self.subTest(
tensor_model_parallel_world_size=tensor_model_parallel_world_size
):
if async_tensor_model_parallel_allreduce and sequence_parallel_enabled:
if tensor_model_parallel_world_size == 1:
continue
with self.subTest(tensor_model_parallel_world_size=tensor_model_parallel_world_size):
if self.world_size % tensor_model_parallel_world_size:
continue
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_parallel_world_size,
)
feature_size_coeff = self.INPUT_SIZE_COEFF
feature_size = feature_size_coeff * tensor_model_parallel_world_size
hidden_size = feature_size
input_tensor_shape = self.tensor_shape
expected_output_shape = self.tensor_shape
# When sequence parallelism is enabled, `gather_output` is disabled, i.e., the
# output of the matmul is not gathered along the feature/hidden (last) dimension.
if sequence_parallel_enabled:
expected_output_shape[-1] //= tensor_model_parallel_world_size
# tensor's shape is [sequence length, batch size, hidden size]
set_random_seed(self.SEED)
input_tensor = torch.randn(
self.BATCH_SIZE,
hidden_size,
feature_size,
device="cuda",
requires_grad=True,
)
input_tensor.retain_grad()
loss_weight = torch.randn(
(self.BATCH_SIZE, hidden_size, feature_size,),
device="cuda",
)
linear = layers.ColumnParallelLinear(
feature_size,
feature_size,
self.HIDDEN_SIZE,
self.HIDDEN_SIZE,
bias=False,
keep_master_weight_for_test=True,
params_dtype=torch.float32,
use_cpu_initialization=True,
no_async_tensor_model_parallel_allreduce=no_async_tensor_model_parallel_allreduce,
gather_output=not sequence_parallel_enabled,
no_async_tensor_model_parallel_allreduce=not async_tensor_model_parallel_allreduce,
gradient_accumulation_fusion=gradient_accumulation_fusion,
accumulation_in_fp16=accumulation_in_fp16,
sequence_parallel_enabled=sequence_parallel_enabled,
).cuda()
if accumulation_in_fp16:
linear = linear.half()
# Simulate the situation where fusion of weight grad calculation and gradient accumulation happens.
if gradient_accumulation_fusion:
with torch.no_grad():
linear.weight.main_grad = torch.randn_like(linear.weight)
linear.weight.main_grad = torch.zeros_like(linear.weight)
orig_input_tensor = torch.randn(input_tensor_shape, device="cuda", requires_grad=True)
if accumulation_in_fp16:
orig_input_tensor = orig_input_tensor.half()
if sequence_parallel_enabled:
input_tensor = list(
orig_input_tensor.chunk(tensor_model_parallel_world_size, dim=0)
)[parallel_state.get_tensor_model_parallel_rank()]
else:
input_tensor = orig_input_tensor
output, _ = linear(input_tensor)
self.assertEqual(
output.shape,
(self.BATCH_SIZE, hidden_size, feature_size,),
)
# The order of dimension is expected to be (sequence, batch, hidden)
self.assertEqual(output.shape, expected_output_shape)
orig_loss_weight = torch.randn(input_tensor_shape, device="cuda")
if accumulation_in_fp16:
orig_loss_weight = orig_loss_weight.half()
if sequence_parallel_enabled:
loss_weight = orig_loss_weight.chunk(
tensor_model_parallel_world_size, dim=2,
)[parallel_state.get_tensor_model_parallel_rank()]
else:
loss_weight = orig_loss_weight
loss = torch.mul(output, loss_weight).sum()
loss.backward()
with torch.no_grad():
dldy = loss_weight.clone()
x = input_tensor.clone()
a = linear.master_weight.cuda().clone()
dldx = torch.matmul(dldy, a)
self.assertEqual(input_tensor.grad, dldx)
# TODO(mkozuki): Cover the other cases.
if (
tensor_model_parallel_world_size == 1
and not gradient_accumulation_fusion
):
dlda = torch.matmul(torch.transpose(dldy, 1, 2), x).sum(dim=0)
curr_dlda = torch.split(dlda, feature_size_coeff, dim=0)[
parallel_state.get_tensor_model_parallel_rank()
]
self.assertEqual(linear.weight.grad, curr_dlda)
dldy = orig_loss_weight.clone()
x = orig_input_tensor.clone()
ref_linear = nn.Linear(
in_features=self.HIDDEN_SIZE,
out_features=self.HIDDEN_SIZE,
bias=False,
device="cuda",
)
if accumulation_in_fp16:
ref_linear = ref_linear.half()
# NOTE(mkozuki): `master_weight` is available because `keep_master_weight_for_test` is set.
ref_linear.weight.copy_(linear.master_weight)
x.requires_grad_()
expected_output = ref_linear(x)
if sequence_parallel_enabled:
chunk = expected_output.chunk(
tensor_model_parallel_world_size,
dim=2,
)[parallel_state.get_tensor_model_parallel_rank()]
torch.testing.assert_close(
actual=output,
expected=chunk,
)
else:
torch.testing.assert_close(
actual=output,
expected=expected_output,
)
expected_loss = torch.mul(expected_output, dldy).sum()
expected_loss.backward()
grad_attr_name = "main_grad" if gradient_accumulation_fusion else "grad"
# NOTE(mkozuki): Numerical errors seem to be enlarged by tensor model parallelism.
if tensor_model_parallel_world_size == 1:
torch.testing.assert_close(
actual=getattr(linear.weight, grad_attr_name),
expected=ref_linear.weight.grad.chunk(
chunks=tensor_model_parallel_world_size,
dim=0,
)[parallel_state.get_tensor_model_parallel_rank()],
)
parallel_state.destroy_model_parallel()
@@ -3,13 +3,13 @@ import logging
import torch
from torch.testing._internal import common_utils
logging.getLogger("torch").setLevel(logging.WARNING)
from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import mappings
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase
logging.getLogger("torch").setLevel(logging.WARNING)
logging.getLogger("apex").setLevel(logging.WARNING)
@@ -49,7 +49,7 @@ class MappingTestBase:
for rank in range(tensor_model_paralell_world_size)
]
x = torch.cat(tensors, 1)
out = mappings._split(x)
out = mappings._split_along_last_dim(x)
self.assertTrue(
torch.equal(
out, tensors[parallel_state.get_tensor_model_parallel_rank()]
@@ -68,7 +68,7 @@ class MappingTestBase:
tensor_model_parallel_size_=tensor_model_paralell_world_size
)
device = f"cuda:{self.rank}"
gathered = mappings._gather(
gathered = mappings._gather_along_last_dim(
torch.tensor(
[parallel_state.get_tensor_model_parallel_rank()], device=device
)