Unverified Commit 3ff1a10f authored by Masaki Kozuki, committed by GitHub

[transformer] Port Sequence Parallelism (takeover of #1396) (#1400)

* it looks possible to remove this file

* add communication collectives

* update Column|RowParallelLinear

* update checkpoint function

* update function name

* parity between public and private collectives

* row parallel linear

* column parallel linear

* sequence parallel: p2p comm

fix typo

* sequence parallel: pipeline parallel

* fix typo

* add layernorm with sequence_parallel_enabled attr

* class variable -> member variable

* fix col parallel test with sequence parallel

* Initial test of `forward_backward_pipelining_without_interleaving` with `model_type=ModelType.encoder_and_decoder`

* add cases pretending to test sequence_parallel

* Apply 2 suggestion(s) to 1 file(s)

* update sequence_parallel_enabled docstring

* update docstring: order of tensor dimensions, sequence_parallel_enabled behavior

* Divide sequence_length if sequence parallel

tensor shape should be updated if sequence parallel is enabled.

* cherry-pick https://github.com/NVIDIA/Megatron-LM/commit/8474e6e54fcb9dfa37aea039352f9fb485fb6f61

* type annotation

* Fix matmul call in RowParallelLinear

Fix `sequence_parallel_enabled` to `False`, as seen in
https://github.com/NVIDIA/Megatron-LM/blob/d898a8991d1a08d29074f87819d1bf41517e35f5/megatron/mpu/layers.py#L511-L514

* update rowparallellinear test

* fix `loss_weight` not defined in test_layers

* @eqy's comment

* mixed fused layer norm

* fix typo

* misc

* test_layers cleanup

* Skip Bert/GPT script

Since these two models haven't been updated for sequence parallel yet, e.g. the change of dimension order from (batch, sequence, feature) to (sequence, batch, feature) and the global variables for arguments

* debug part 1/N: comment out `x.retain_grad`

* debug part 2/N: [ColumnParallelLinear] comment out overriding of sequence_parallel_enabled

* debug 3/N: add pipeline test with parallel mlp

* Fix handling `self.input_tensor` and argument

* tp2pp4 ModelType.encoder_or_decoder is failing, which may be my fault because the backward pass complains that the output and grad_output shapes don't match

* revert debug 1/N

* defer tensor model parallel size > 1

* split tensor in sequence dim

* cosmetic

* cosmetic: remove archaic comment

* enable TP>1 for encoder_and_decoder as well

* set requires_grad=True always...

* Set `scatter_gather_tensors_in_pipeline` to :obj:`False`

so that NeMo Megatron's GPT works with sequence parallel enabled.

* brush up comment of `requires_grad()`

According to @ptrblck, there's a possibility that PyTorch DistributedDataParallel hangs
when some tensor (or parameter) doesn't require grad.
In my understanding, this forced `requires_grad` is different from that.

* misc changes of scatter_gather_tensors_in_pipeline comment

* guard for torch_ucc

* cosmetic changes related to tests

* update command line arguments

* update TransformerLanguageModel

* rename

* move gpt to gpt.py

* update bert

* add all_gather for params in sequence parallel region

* misc. some diffs were lost during rebasing...

* updates for non sequence parallel execution

* gpt with sequence parallel

* Apply 2 suggestion(s) to 2 file(s)

* update tensor&pipeline parallel size

* why is `sequence_parallel_enabled` not supplied!? Did I mess up when rebasing?

* cosmetic fix

* correct key is sequence_parallel_enabled
parent 57f890a7
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
from apex.transformer.layers.layer_norm import FastLayerNorm
from apex.transformer.layers.layer_norm import FusedLayerNorm
from apex.transformer.layers.layer_norm import MixedFusedLayerNorm
__all__ = [
"FastLayerNorm",
"FusedLayerNorm",
"MixedFusedLayerNorm",
]
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# NOTE(mkozuki): This file defines LayerNorm variants that are compatible with Megatron-LM,
# while avoiding introducing the breaking change of the `"sequence_parallel_enabled"` attribute into apex.normalization.FusedLayerNorm
# and apex.contrib.layer_norm.FastLayerNorm.
import warnings
import torch
from apex.normalization import FusedLayerNorm as OrigFusedLayerNorm
from apex.normalization import MixedFusedLayerNorm as OrigMixedFusedLayerNorm
try:
from apex.contrib.layer_norm import FastLayerNorm as OrigFastLayerNorm
except ImportError:
HAS_FAST_LAYER_NORM = False
else:
HAS_FAST_LAYER_NORM = True
__all__ = [
"FusedLayerNorm",
"FastLayerNorm",
"MixedFusedLayerNorm",
]
def _set_sequence_parallel_enabled(
param: torch.Tensor,
sequence_parallel_enabled: bool,
) -> None:
setattr(param, "sequence_parallel_enabled", sequence_parallel_enabled)
class FusedLayerNorm(OrigFusedLayerNorm):
def __init__(
self,
normalized_shape,
eps: float = 1e-5,
elementwise_affine: bool = True,
*,
sequence_parallel_enabled: bool = False,
):
super().__init__(
normalized_shape=normalized_shape,
eps=eps,
elementwise_affine=elementwise_affine,
)
self.sequence_parallel_enabled = sequence_parallel_enabled
if self.elementwise_affine:
_set_sequence_parallel_enabled(self.weight, self.sequence_parallel_enabled)
_set_sequence_parallel_enabled(self.bias, self.sequence_parallel_enabled)
# note: MixedFusedLayerNorm is no different from FusedLayerNorm if it's used in `torch.cuda.amp`.
class MixedFusedLayerNorm(OrigMixedFusedLayerNorm):
def __init__(
self,
normalized_shape,
eps: float = 1e-5,
**kwargs,
) -> None:
self.sequence_parallel_enabled = kwargs.get("sequence_parallel_enabled", False)
super().__init__(normalized_shape=normalized_shape, eps=eps, **kwargs)
if self.sequence_parallel_enabled:
_set_sequence_parallel_enabled(self.weight, self.sequence_parallel_enabled)
_set_sequence_parallel_enabled(self.bias, self.sequence_parallel_enabled)
if HAS_FAST_LAYER_NORM:
class FastLayerNorm(OrigFastLayerNorm):
def __init__(
self,
hidden_size,
eps: float = 1e-5,
*,
sequence_parallel_enabled: bool = False,
):
super().__init__(
hidden_size=hidden_size,
eps=eps
)
self.sequence_parallel_enabled = sequence_parallel_enabled
_set_sequence_parallel_enabled(self.weight, self.sequence_parallel_enabled)
_set_sequence_parallel_enabled(self.bias, self.sequence_parallel_enabled)
else:
class FastLayerNorm(FusedLayerNorm):
def __init__(
self,
hidden_size,
eps: float = 1e-5,
*,
sequence_parallel_enabled: bool = False,
):
warnings.warn("`apex.contrib.layer_norm.FastLayerNorm` isn't available thus falling back to `apex.normalization.FusedLayerNorm`")
super().__init__(
normalized_shape=hidden_size,
eps=eps,
elementwise_affine=True,
sequence_parallel_enabled=sequence_parallel_enabled,
)
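For illustration, a minimal single-process sketch of what the `sequence_parallel_enabled` tag on the affine parameters is for (plain `torch.nn.LayerNorm` stands in for the fused variants, and `collect_sequence_parallel_params` is a hypothetical helper; downstream code gives tagged parameters special treatment, e.g. the "all_gather for params in sequence parallel region" mentioned in the commit messages above):

import torch
from torch import nn


def _set_sequence_parallel_enabled(param: torch.Tensor, enabled: bool) -> None:
    # Same tagging pattern as the wrappers above: downstream code inspects
    # `param.sequence_parallel_enabled` to decide whether the parameter needs
    # special handling across the tensor model parallel group.
    setattr(param, "sequence_parallel_enabled", enabled)


def collect_sequence_parallel_params(module: nn.Module):
    # Hypothetical helper: gather every parameter tagged as sequence parallel.
    return [p for p in module.parameters() if getattr(p, "sequence_parallel_enabled", False)]


ln = nn.LayerNorm(16)
_set_sequence_parallel_enabled(ln.weight, True)
_set_sequence_parallel_enabled(ln.bias, True)
params = collect_sequence_parallel_params(ln)
assert params[0] is ln.weight and params[1] is ln.bias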
# coding=utf-8
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021-22, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO(mkozuki): Consider removing `timers`.
from functools import reduce
import operator
......@@ -20,12 +21,15 @@ from typing import Union, Optional, Tuple
import torch
from apex.transformer import parallel_state
from apex.transformer.log_util import get_transformer_logger
from apex.transformer.utils import split_tensor_into_1d_equal_chunks
from apex.transformer.utils import gather_split_1d_tensor
from apex.transformer.pipeline_parallel.utils import Shape
from apex.transformer.pipeline_parallel._timers import _Timers
_logger = get_transformer_logger(__name__)
class FutureTensor:
def __init__(self, tensor: torch.Tensor, waitfunc):
......@@ -42,11 +46,11 @@ class FutureTensor:
def _run_p2pops(
tensor_send_prev: Union[torch.Tensor, None],
tensor_send_next: Union[torch.Tensor, None],
tensor_recv_prev: Union[torch.Tensor, None],
tensor_recv_next: Union[torch.Tensor, None],
async_comm: bool = False
tensor_send_prev: Union[torch.Tensor, None],
tensor_send_next: Union[torch.Tensor, None],
tensor_recv_prev: Union[torch.Tensor, None],
tensor_recv_next: Union[torch.Tensor, None],
async_comm: bool = False
):
ops = []
if tensor_send_prev is not None:
......@@ -93,6 +97,11 @@ def _run_p2pops(
return (None, None, None, None)
# TODO(mkozuki): Check if it's possible to sunset `override_scatter_gather_tensors_in_pipeline`.
# TODO(mkozuki): Think about whether it's possible to push some logic and arguments, e.g.
# `scatter_gather_tensors_in_pipeline`, `sequence_parallel_enabled`, and
# `override_scatter_gather_tensors_in_pipeline`, to the users of
# apex.transformer forward_backward functions.
def _communicate(
tensor_send_next: Optional[torch.Tensor],
tensor_send_prev: Optional[torch.Tensor],
......@@ -106,9 +115,14 @@ def _communicate(
params_dtype: Optional[torch.dtype] = None,
fp32_residual_connection: bool = False,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
) -> Tuple[Union[torch.Tensor, FutureTensor, None], Union[torch.Tensor, FutureTensor, None]]:
"""Base function for communication of tensors between stages.
.. note::
Reference https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/blob/cfd2e2160700b7f2c1bf35298ac14bc341f4c759/megatron/p2p_communication.py#L24-L159
dtype logic: If none of ``dtype_``, ``params_dtype``, ``fp32_residual_connection`` is specified,
torch.float32 is used.
......@@ -130,6 +144,9 @@ def _communicate(
params_dtype: Optional and legacy. Defaults to torch.float. If you manually call `.half()` or `.bfloat16()` on
your model deliberately, pass this argument.
fp32_residual_connection: Optional. If :obj:`True`, move residual connections to fp32.
sequence_parallel_enabled: Set to :obj:`True` if sequence parallel is enabled.
This argument is here for consistency with Megatron-LM.
This argument affects the communication optimization, but not the tensor_shape update.
Returns:
tuple containing
......@@ -137,6 +154,13 @@ def _communicate(
- tensor_recv_prev: `torch.Tensor` if `recv_prev` is :obj:`True`, `None` otherwise.
- tensor_recv_next: `torch.Tensor` if `recv_next` is :obj:`True`, `None` otherwise.
"""
if async_comm and sequence_parallel_enabled:
import warnings # NOQA
class ExperimentalWarning(UserWarning): pass # NOQA
warnings.warn(
"The combination of `async_comm` and `sequence_parallel_enabled` is not well tested.",
ExperimentalWarning,
)
# Create placeholder tensors for receive in forward and backward directions if needed.
tensor_recv_prev = None
tensor_recv_next = None
......@@ -144,25 +168,45 @@ def _communicate(
# In megatron, `tensor_shape` is set to `(args.seq_length, args.micro_batch_size, args.hidden_size)`
raise RuntimeError(
"`tensor_shape` must be specified. Common `tensor_shape` is `(seq_length, micro_batch_size, hidden_size)`")
if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
tensor_chunk_shape = (reduce(operator.mul, tensor_shape, 1) // parallel_state.get_tensor_model_parallel_world_size(),)
tensor_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
override_scatter_gather_tensors_in_pipeline_ = False
# TODO(mkozuki): Demystify the hardcoded `False` of `scatter_gather_tensors_in_pipeline` and add a test case if possible.
# NOTE(mkozuki): This is quite strange and doesn't make sense to me. I have no idea what is happening here.
# However, I can say that this hardcoded override is necessary for sequence parallel in NeMo Megatron to work.
# I've not managed to reproduce the hang using standalone GPT with sequence parallel.
# The hang in NeMo Megatron happens in the 3rd iteration, the last iteration of the steady phase inside
# forward_backward_pipelining_without_interleaving, at pipeline parallel rank 0 (tensor model parallel world
# size of 2 and pipeline model parallel world size of 2). The APEX and NeMo commits at the time were
# https://github.com/NVIDIA/apex/pull/1396/commits/3060c98dd8ba42abf7702ea9d2cff0f39ea74f45 and
# https://github.com/NVIDIA/NeMo/pull/4232/commits/1cb32dfca2ab9b20f53ebdb84476c34cb42f0205.
# The PyTorch version was 1.13.0a0+git2d354cd, for what it's worth.
# Currently this is indiscriminately set to `False`, which can lead to an unexpected performance regression
# in the non-sequence-parallel case.
scatter_gather_tensors_in_pipeline = False
if scatter_gather_tensors_in_pipeline and not sequence_parallel_enabled:
tensor_chunk_size = int(reduce(operator.mul, tensor_shape, 1))
if tensor_chunk_size % tensor_parallel_size == 0:
tensor_chunk_shape = [tensor_chunk_size // tensor_parallel_size]
else:
tensor_chunk_shape = tensor_shape
override_scatter_gather_tensors_in_pipeline_ = True
else:
tensor_chunk_shape = tensor_shape
# The dtype logic below is copied from NVIDIA/Megatron-LM repo:
# https://github.com/NVIDIA/Megatron-LM/blob/d41696840ed0a7edb7e0499eb82a48ae112d9bb3/megatron/p2p_communication.py#L74-L81
# NOTE (mkozuki): Currently NeMo implements APEX AMP O2 style using PyTorch. In O2 style, forcing p2p comm to
# use FP32 would be a perf killer, so I decided to reinstate the `dtype_` argument with a default value of `None`.
# NOTE (mkozuki): In PyTorch AMP, i.e. `torch.cuda.amp.autocast` context, activation tensors can be either FP32,
# FP16, or BF16 and there's no way to tell the dtypes of tensors on different devices in general.
# It might be possible if we restrict model architecture.
dtype = params_dtype or torch.float
if fp32_residual_connection:
dtype = torch.float
requires_grad = True
if dtype_ is not None:
dtype = dtype_
requires_grad = False
# TODO(mkozuki): Figure out why this logic of requires_grad isn't working
# when sequence_parallel_enabled=True. Otherwise, `x.retain_grad()` of
# https://github.com/crcrpar/apex/blob/069832078a652b4bd8a99db84faf953a81415ab3/apex/transformer/pipeline_parallel/schedules/common.py#L360
# fails.
# requires_grad = False
if recv_prev:
tensor_recv_prev = torch.empty(
......@@ -180,7 +224,12 @@ def _communicate(
)
# Split tensor into smaller chunks if using scatter-gather optimization.
if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
scatter_gather_optimization_doable = (
not override_scatter_gather_tensors_in_pipeline_
and scatter_gather_tensors_in_pipeline
and not sequence_parallel_enabled
)
if scatter_gather_optimization_doable:
if tensor_send_next is not None:
tensor_send_next = split_tensor_into_1d_equal_chunks(tensor_send_next)
......@@ -210,7 +259,7 @@ def _communicate(
torch.cuda.synchronize()
# If using scatter-gather optimization, gather smaller chunks.
if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
if scatter_gather_optimization_doable:
if not async_comm:
if recv_prev:
tensor_recv_prev = (
......@@ -218,7 +267,7 @@ def _communicate(
.view(tensor_shape)
.requires_grad_()
)
if recv_next:
tensor_recv_next = (
gather_split_1d_tensor(tensor_recv_next)
......@@ -254,17 +303,17 @@ def _communicate(
if tensor_recv_next is not None:
future_tensor_recv_next = FutureTensor(tensor_recv_next, tensor_recv_next_waitfunc)
return future_tensor_recv_prev, future_tensor_recv_next
return tensor_recv_prev, tensor_recv_next
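For intuition, a single-process toy of the scatter-gather optimization that the `scatter_gather_optimization_doable` flag above guards (the two functions below are simplified stand-ins for `split_tensor_into_1d_equal_chunks` and `gather_split_1d_tensor`; no process group is involved). Without sequence parallelism every tensor parallel rank holds an identical activation, so each rank can send just its 1/world_size slice and the receiver reconstructs the full tensor; with sequence parallelism each rank already holds a different sequence shard, which is why the optimization is skipped.

import torch


def _split_1d_sketch(t: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # Simplified stand-in for `split_tensor_into_1d_equal_chunks`:
    # flatten and keep only this rank's 1/world_size slice.
    flat = t.reshape(-1)
    chunk = flat.numel() // world_size
    return flat[rank * chunk:(rank + 1) * chunk].clone()


def _gather_1d_sketch(chunks, shape) -> torch.Tensor:
    # Simplified stand-in for `gather_split_1d_tensor` + `.view(tensor_shape)`:
    # concatenate the per-rank chunks and restore the original shape.
    return torch.cat(chunks).view(shape)


full = torch.arange(2 * 3 * 4, dtype=torch.float32).reshape(2, 3, 4)
chunks = [_split_1d_sketch(full, r, world_size=2) for r in range(2)]
assert torch.equal(_gather_1d_sketch(chunks, full.shape), full)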
def recv_forward(
tensor_shape: Shape,
override_scatter_gather_tensors_in_pipeline: bool = False,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
async_comm: bool = False,
tensor_shape: Shape,
override_scatter_gather_tensors_in_pipeline: bool = False,
*,
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
timers: _Timers = None,
) -> Union[torch.Tensor, FutureTensor, None]:
"""Receive tensor from previous rank in pipeline (forward receive)."""
if parallel_state.is_pipeline_first_stage():
......@@ -280,6 +329,7 @@ def recv_forward(
override_scatter_gather_tensors_in_pipeline=override_scatter_gather_tensors_in_pipeline,
dtype_=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
# if timers is not None:
# timers("forward-recv").stop()
......@@ -287,11 +337,12 @@ def recv_forward(
def recv_backward(
tensor_shape: Shape = None,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
async_comm: bool = False,
tensor_shape: Shape = None,
*,
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
timers: _Timers = None,
) -> Union[torch.Tensor, FutureTensor, None]:
"""Receive tensor from next rank in pipeline (backward receive)."""
if parallel_state.is_pipeline_last_stage():
......@@ -306,6 +357,7 @@ def recv_backward(
tensor_shape=tensor_shape,
dtype_=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
# if timers is not None:
# timers("backward-recv").stop()
......@@ -313,13 +365,14 @@ def recv_backward(
def send_forward(
output_tensor: torch.Tensor,
override_scatter_gather_tensors_in_pipeline: bool = False,
tensor_shape: Shape = None,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
async_comm: bool = False,
output_tensor: torch.Tensor,
override_scatter_gather_tensors_in_pipeline: bool = False,
tensor_shape: Shape = None,
*,
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
timers: _Timers = None,
) -> None:
"""Send tensor to next rank in pipeline (forward send)."""
if parallel_state.is_pipeline_last_stage():
......@@ -335,19 +388,20 @@ def send_forward(
tensor_shape=tensor_shape,
dtype_=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
# if timers is not None:
# timers("forward-send").stop()
def send_backward(
input_tensor_grad: torch.Tensor,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
async_comm: bool = False,
input_tensor_grad: torch.Tensor,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
timers: _Timers = None,
) -> None:
"""Send tensor to previous rank in pipeline (backward send)."""
if parallel_state.is_pipeline_first_stage():
......@@ -362,18 +416,20 @@ def send_backward(
tensor_shape=tensor_shape,
dtype_=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
# if timers is not None:
# timers("backward-send").stop()
def send_forward_recv_backward(
output_tensor: torch.Tensor,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
async_comm: bool = False,
output_tensor: torch.Tensor,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
timers: _Timers = None,
) -> Union[torch.Tensor, FutureTensor, None]:
"""Batched send and recv with next rank in pipeline."""
if parallel_state.is_pipeline_last_stage():
......@@ -388,6 +444,7 @@ def send_forward_recv_backward(
tensor_shape=tensor_shape,
dtype_=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
# if timers is not None:
# timers("forward-send-backward-recv").stop()
......@@ -395,12 +452,13 @@ def send_forward_recv_backward(
def send_backward_recv_forward(
input_tensor_grad: torch.Tensor,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
async_comm: bool = False,
input_tensor_grad: torch.Tensor,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
timers: _Timers = None,
) -> Union[torch.Tensor, FutureTensor, None]:
"""Batched send and recv with previous rank in pipeline."""
if parallel_state.is_pipeline_first_stage():
......@@ -415,6 +473,7 @@ def send_backward_recv_forward(
tensor_shape=tensor_shape,
dtype_=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
# if timers is not None:
# timers("backward-send-forward-recv").stop()
......@@ -422,13 +481,14 @@ def send_backward_recv_forward(
def send_forward_recv_forward(
output_tensor: torch.Tensor,
recv_prev: bool,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
async_comm: bool = False,
output_tensor: torch.Tensor,
recv_prev: bool,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
timers: _Timers = None,
) -> Union[torch.Tensor, FutureTensor]:
"""Batched recv from previous rank and send to next rank in pipeline."""
# if timers is not None:
......@@ -441,6 +501,7 @@ def send_forward_recv_forward(
tensor_shape=tensor_shape,
dtype_=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
# if timers is not None:
# timers("forward-send-forward-recv").stop()
......@@ -448,13 +509,14 @@ def send_forward_recv_forward(
def send_backward_recv_backward(
input_tensor_grad: torch.Tensor,
recv_next: bool,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
async_comm: bool = False,
input_tensor_grad: torch.Tensor,
recv_next: bool,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
timers: _Timers = None,
) -> Union[torch.Tensor, FutureTensor]:
"""Batched recv from next rank and send to previous rank in pipeline."""
# if timers is not None:
......@@ -467,6 +529,7 @@ def send_backward_recv_backward(
tensor_shape=tensor_shape,
dtype_=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
# if timers is not None:
# timers("backward-send-backward-recv").stop()
......@@ -474,15 +537,16 @@ def send_backward_recv_backward(
def send_forward_backward_recv_forward_backward(
output_tensor: torch.Tensor,
input_tensor_grad: torch.Tensor,
recv_prev: bool,
recv_next: bool,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
async_comm: bool = False,
output_tensor: torch.Tensor,
input_tensor_grad: torch.Tensor,
recv_prev: bool,
recv_next: bool,
tensor_shape: Shape,
*,
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
timers: _Timers = None,
) -> Tuple[Union[torch.Tensor, FutureTensor], Union[torch.Tensor, FutureTensor]]:
"""Batched send and recv with previous and next ranks in pipeline."""
# if timers is not None:
......@@ -495,6 +559,7 @@ def send_forward_backward_recv_forward_backward(
tensor_shape=tensor_shape,
dtype_=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
# if timers is not None:
# timers("forward-backward-send-forward-backward-recv").stop()
......
......@@ -34,6 +34,7 @@ def _forward_backward_pipelining_with_interleaving(
grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
disable_autocast: bool = False,
deallocate_pipeline_outputs: bool = False,
sequence_parallel_enabled: bool = False,
**kwargs,
) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
"""Run interleaved 1F1B schedule with communication between pipeline stages as needed.
......@@ -57,13 +58,17 @@ def _forward_backward_pipelining_with_interleaving(
Keyword args:
forward_only:
tensor_shape: Shape of tensor.
tensor_shape: Shape of tensor. The tensor is expected to be 3D and its dimension order
is expected to be ``(sequence, batch, hidden)``.
dtype: dtype used in p2p communication. If ``None`` (default value),
torch.float32 will be used even if ``autocast`` is enabled.
grad_scaler:
disable_autocast:
deallocate_pipeline_outputs: If :obj:`True`, free the data of the output tensor of
each pipeline stage. Experimental.
sequence_parallel_enabled: Set to :obj:`True` for this function to handle sequence length.
When :obj:`True`, the sequence length on each tensor model parallel rank is updated
to :math:`original\_sequence\_length / tensor\_model\_parallel\_world\_size`.
Returns:
a list of loss `torch.Tensor`s if the last stage, empty list otherwise.
......@@ -77,6 +82,15 @@ def _forward_backward_pipelining_with_interleaving(
"This option is not recommended."
)
# mypy will complain about the following if statement
if sequence_parallel_enabled:
seq_length, batch_size, hidden = tensor_shape
tensor_shape = (
seq_length // parallel_state.get_tensor_model_parallel_world_size(),
batch_size,
hidden,
)
num_model_chunks: int = len(model)
input_tensors: List[List[Union[None, torch.Tensor]]] = [
[] for _ in range(num_model_chunks)
......@@ -201,7 +215,11 @@ def _forward_backward_pipelining_with_interleaving(
###################################################################################################################
parallel_state.set_virtual_pipeline_model_parallel_rank(0)
input_tensors[0].append(
p2p_communication.recv_forward(tensor_shape=tensor_shape, dtype=dtype)
p2p_communication.recv_forward(
tensor_shape=tensor_shape,
dtype=dtype,
sequence_parallel_enabled=sequence_parallel_enabled,
)
)
_logger.info("Warmup phase")
for k in range(num_warmup_microbatches):
......@@ -247,6 +265,7 @@ def _forward_backward_pipelining_with_interleaving(
recv_next=recv_next,
tensor_shape=tensor_shape,
dtype=dtype,
sequence_parallel_enabled=sequence_parallel_enabled,
)
output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
else:
......@@ -256,9 +275,10 @@ def _forward_backward_pipelining_with_interleaving(
recv_prev=recv_prev,
tensor_shape=tensor_shape,
dtype=dtype,
sequence_parallel_enabled=sequence_parallel_enabled,
)
free_output_tensor(output_tensor, deallocate_pipeline_outputs)
input_tensors[next_forward_model_chunk_id].append(input_tensor)
free_output_tensor(output_tensor, deallocate_pipeline_outputs)
###################################################################################################################
# Run 1F1B in steady state.
......@@ -339,6 +359,7 @@ def _forward_backward_pipelining_with_interleaving(
recv_next=recv_next,
tensor_shape=tensor_shape,
dtype=dtype,
sequence_parallel_enabled=sequence_parallel_enabled,
)
free_output_tensor(output_tensor, deallocate_pipeline_outputs)
......@@ -356,7 +377,11 @@ def _forward_backward_pipelining_with_interleaving(
if not forward_only:
if all_warmup_microbatches:
output_tensor_grads[num_model_chunks - 1].append(
p2p_communication.recv_backward(tensor_shape=tensor_shape, dtype=dtype)
p2p_communication.recv_backward(
tensor_shape=tensor_shape,
dtype=dtype,
sequence_parallel_enabled=sequence_parallel_enabled,
)
)
for k in range(num_microbatches_remaining, num_microbatches):
_logger.debug(
......@@ -376,6 +401,7 @@ def _forward_backward_pipelining_with_interleaving(
recv_next=recv_next,
tensor_shape=tensor_shape,
dtype=dtype,
sequence_parallel_enabled=sequence_parallel_enabled,
)
)
......
......@@ -31,7 +31,19 @@ def get_tensor_shapes(
*,
tensor_shape: Union[List[int], torch.Size],
decoder_sequence_length: Optional[int] = None,
sequence_parallel_enabled: bool = False,
) -> Sequence[Sequence[int]]:
"""Get tensors shapes
Args:
rank: pipeline parallel rank
model_type:
Keyword Args:
tensor_shape:
decoder_sequence_length:
sequence_parallel_enabled:
"""
# Determine right tensor sizes (based on position of rank with respect to split
# rank) and model size.
# Send two tensors if model is T5 and rank is in decoder stage:
......@@ -44,21 +56,27 @@ def get_tensor_shapes(
len(tensor_shape) == 3
), f"`tensor_shape` should be [sequence_length, micro_batch_size, hidden_size] but {tensor_shape}"
sequence_length, micro_batch_size, hidden_size = tensor_shape
seq_len = sequence_length
if sequence_parallel_enabled:
seq_len = sequence_length // parallel_state.get_tensor_model_parallel_world_size()
tensor_shapes = []
if model_type == ModelType.encoder_and_decoder:
if decoder_sequence_length is None:
raise ValueError("`decoder_sequence_length` is required for `ModelType.encoder_and_decoder`")
dec_seq_len = decoder_sequence_length
if sequence_parallel_enabled:
dec_seq_len = decoder_sequence_length // parallel_state.get_tensor_model_parallel_world_size()
if parallel_state.is_pipeline_stage_before_split(rank):
# If next rank is after split, then need transpose for encoder_hidden_state.
if parallel_state.is_pipeline_stage_before_split(rank + 1):
tensor_shapes.append((sequence_length, micro_batch_size, hidden_size))
tensor_shapes.append((seq_len, micro_batch_size, hidden_size))
else:
tensor_shapes.append((micro_batch_size, sequence_length, hidden_size))
tensor_shapes.append((dec_seq_len, micro_batch_size, hidden_size))
else:
tensor_shapes.append((decoder_sequence_length, micro_batch_size, hidden_size))
tensor_shapes.append((micro_batch_size, sequence_length, hidden_size))
tensor_shapes.append((dec_seq_len, micro_batch_size, hidden_size))
tensor_shapes.append((seq_len, micro_batch_size, hidden_size))
else:
tensor_shapes.append((sequence_length, micro_batch_size, hidden_size))
tensor_shapes.append((seq_len, micro_batch_size, hidden_size))
return tensor_shapes
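A small hand-check of the shape arithmetic above (pure Python, no distributed setup; the tensor model parallel world size is passed explicitly instead of being read from `parallel_state`):

def expected_p2p_shape(seq_len: int, micro_batch_size: int, hidden_size: int,
                       tp_world_size: int, sequence_parallel_enabled: bool):
    # Mirrors the `seq_len` adjustment in `get_tensor_shapes`: with sequence
    # parallelism each tensor model parallel rank only exchanges its own
    # sequence shard between pipeline stages.
    if sequence_parallel_enabled:
        seq_len //= tp_world_size
    return (seq_len, micro_batch_size, hidden_size)


assert expected_p2p_shape(1024, 4, 2048, 2, False) == (1024, 4, 2048)
assert expected_p2p_shape(1024, 4, 2048, 2, True) == (512, 4, 2048)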
......@@ -67,13 +85,21 @@ def recv_forward(
*,
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
) -> List[Union[None, torch.Tensor, FutureTensor]]:
input_tensors = []
for tensor_shape in tensor_shapes:
if tensor_shape is None:
input_tensors.append(None)
else:
input_tensors.append(p2p_communication.recv_forward(tensor_shape=tensor_shape, dtype=dtype, async_comm=async_comm))
input_tensors.append(
p2p_communication.recv_forward(
tensor_shape=tensor_shape,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
)
return input_tensors
......@@ -82,13 +108,21 @@ def recv_backward(
*,
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
) -> List[Union[None, torch.Tensor, FutureTensor]]:
output_tensor_grads = []
for tensor_shape in tensor_shapes:
if tensor_shape is None:
output_tensor_grads.append(None)
else:
output_tensor_grads.append(p2p_communication.recv_backward(tensor_shape=tensor_shape, dtype=dtype, async_comm=async_comm))
output_tensor_grads.append(
p2p_communication.recv_backward(
tensor_shape=tensor_shape,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
)
return output_tensor_grads
......@@ -98,13 +132,20 @@ def send_forward(
*,
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
) -> None:
if not isinstance(output_tensors, list):
output_tensors = [output_tensors]
for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes):
if tensor_shape is None:
continue
p2p_communication.send_forward(output_tensor, tensor_shape=tensor_shape, dtype=dtype, async_comm=async_comm)
p2p_communication.send_forward(
output_tensor,
tensor_shape=tensor_shape,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
def send_backward(
......@@ -113,13 +154,20 @@ def send_backward(
*,
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
) -> None:
if not isinstance(input_tensor_grads, list):
input_tensor_grads = [input_tensor_grads]
for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes):
if tensor_shape is None:
continue
p2p_communication.send_backward(input_tensor_grad, tensor_shape=tensor_shape, dtype=dtype, async_comm=async_comm)
p2p_communication.send_backward(
input_tensor_grad,
tensor_shape=tensor_shape,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
def send_forward_recv_backward(
......@@ -128,6 +176,7 @@ def send_forward_recv_backward(
*,
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
) -> List[Union[None, torch.Tensor, FutureTensor]]:
if not isinstance(output_tensors, list):
output_tensors = [output_tensors]
......@@ -136,7 +185,13 @@ def send_forward_recv_backward(
if tensor_shape is None:
output_tensor_grads.append(None)
continue
output_tensor_grad = p2p_communication.send_forward_recv_backward(output_tensor, tensor_shape=tensor_shape, dtype=dtype, async_comm=async_comm)
output_tensor_grad = p2p_communication.send_forward_recv_backward(
output_tensor,
tensor_shape=tensor_shape,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
output_tensor_grads.append(output_tensor_grad)
return output_tensor_grads
......@@ -147,6 +202,7 @@ def send_backward_recv_forward(
*,
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
) -> List[Union[None, torch.Tensor, FutureTensor]]:
if not isinstance(input_tensor_grads, list):
input_tensor_grads = [input_tensor_grads]
......@@ -155,7 +211,13 @@ def send_backward_recv_forward(
if tensor_shape is None:
input_tensors.append(None)
continue
input_tensor = p2p_communication.send_backward_recv_forward(input_tensor_grad, tensor_shape=tensor_shape, dtype=dtype, async_comm=async_comm)
input_tensor = p2p_communication.send_backward_recv_forward(
input_tensor_grad,
tensor_shape=tensor_shape,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
input_tensors.append(input_tensor)
return input_tensors
......@@ -173,6 +235,7 @@ def forward_backward_pipelining_without_interleaving(
disable_autocast: bool = False,
deallocate_pipeline_outputs: bool = False,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
**kwargs,
) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
"""Run non-interleaved 1F1B schedule, with communication between pipeline stages.
......@@ -192,13 +255,17 @@ def forward_backward_pipelining_without_interleaving(
Keyword args:
forward_only:
tensor_shape: Shape of tensor. Required for P2P communication.
tensor_shape: Shape of tensor. The tensor is expected to be 3D and its dimension order
is expected to be ``(sequence, batch, hidden)``.
dtype: dtype used in p2p communication. If ``None`` (default value),
torch.float32 will be used even if ``autocast`` is enabled.
grad_scaler:
disable_autocast:
deallocate_pipeline_outputs: If :obj:`True`, free the data of the output tensor of
each pipeline stage. Experimental.
sequence_parallel_enabled: Set to :obj:`True` for this function to handle sequence length.
When :obj:`True`, the sequence length on each tensor model parallel rank is updated
to :math:`original\_sequence\_length / tensor\_model\_parallel\_world\_size`.
Returns:
a list of loss `torch.Tensor`s if the last stage, empty list otherwise.
......@@ -228,10 +295,18 @@ def forward_backward_pipelining_without_interleaving(
model_type = get_model_type(model)
rank: int = parallel_state.get_pipeline_model_parallel_rank()
recv_tensor_shapes: List[List[int]] = get_tensor_shapes(
rank - 1, model_type, tensor_shape=tensor_shape, decoder_sequence_length=decoder_sequence_length
rank - 1,
model_type,
tensor_shape=tensor_shape,
decoder_sequence_length=decoder_sequence_length,
sequence_parallel_enabled=sequence_parallel_enabled,
)
send_tensor_shapes: List[List[int]] = get_tensor_shapes(
rank, model_type, tensor_shape=tensor_shape, decoder_sequence_length=decoder_sequence_length
rank,
model_type,
tensor_shape=tensor_shape,
decoder_sequence_length=decoder_sequence_length,
sequence_parallel_enabled=sequence_parallel_enabled,
)
_logger.info(
......@@ -251,7 +326,12 @@ def forward_backward_pipelining_without_interleaving(
for i in range(num_warmup_microbatches):
_logger.debug(f"warmup iter: {i} / {num_warmup_microbatches}")
_logger.debug("receive fwd")
input_tensor = recv_forward(tensor_shapes=recv_tensor_shapes, dtype=dtype, async_comm=async_comm)
input_tensor = recv_forward(
tensor_shapes=recv_tensor_shapes,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
cur_microbatch: Optional[torch.Tensor] = get_kth_microbatch(batch, i)
output_tensor = forward_step(
forward_step_func,
......@@ -263,7 +343,13 @@ def forward_backward_pipelining_without_interleaving(
disable_autocast,
)
_logger.debug("send fwd")
send_forward(output_tensor, tensor_shapes=send_tensor_shapes, dtype=dtype, async_comm=async_comm)
send_forward(
output_tensor,
tensor_shapes=send_tensor_shapes,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
if not forward_only:
input_tensors.append(input_tensor)
......@@ -297,15 +383,32 @@ def forward_backward_pipelining_without_interleaving(
)
if forward_only:
_logger.debug("send fwd")
send_forward(output_tensor, tensor_shapes=send_tensor_shapes, dtype=dtype, async_comm=async_comm)
send_forward(
output_tensor,
tensor_shapes=send_tensor_shapes,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
if not last_iteration:
_logger.debug("receive fwd (last iteration)")
input_tensor = recv_forward(tensor_shapes=recv_tensor_shapes, dtype=dtype, async_comm=async_comm)
input_tensor = recv_forward(
tensor_shapes=recv_tensor_shapes,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
else:
_logger.debug("send fwd & receive bwd")
output_tensor_grad = send_forward_recv_backward(output_tensor, tensor_shapes=send_tensor_shapes, dtype=dtype, async_comm=async_comm)
output_tensor_grad = send_forward_recv_backward(
output_tensor,
tensor_shapes=send_tensor_shapes,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
# Add input_tensor and output_tensor to end of list.
input_tensors.append(input_tensor)
......@@ -328,10 +431,22 @@ def forward_backward_pipelining_without_interleaving(
if last_iteration:
input_tensor = None
_logger.debug("send bwd")
send_backward(input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype, async_comm=async_comm)
send_backward(
input_tensor_grad,
tensor_shapes=recv_tensor_shapes,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
else:
_logger.debug("send bwd and receive fwd")
input_tensor = send_backward_recv_forward(input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype, async_comm=async_comm)
input_tensor = send_backward_recv_forward(
input_tensor_grad,
tensor_shapes=recv_tensor_shapes,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
###################################################################################################################
# Run cooldown backward passes.
###################################################################################################################
......@@ -343,7 +458,12 @@ def forward_backward_pipelining_without_interleaving(
output_tensor = output_tensors.pop(0)
_logger.debug("receive bwd")
output_tensor_grad = recv_backward(tensor_shapes=send_tensor_shapes, dtype=dtype, async_comm=async_comm)
output_tensor_grad = recv_backward(
tensor_shapes=send_tensor_shapes,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
input_tensor_grad = backward_step(
input_tensor,
......@@ -355,6 +475,12 @@ def forward_backward_pipelining_without_interleaving(
)
_logger.debug("send bwd")
send_backward(input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype, async_comm=async_comm)
send_backward(
input_tensor_grad,
tensor_shapes=recv_tensor_shapes,
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
)
return losses_reduced
......@@ -32,6 +32,7 @@ from apex.transformer.tensor_parallel.mappings import (
gather_from_tensor_model_parallel_region,
reduce_from_tensor_model_parallel_region,
scatter_to_tensor_model_parallel_region,
scatter_to_sequence_parallel_region,
)
from .random import (
......@@ -62,6 +63,7 @@ __all__ = [
"gather_from_tensor_model_parallel_region",
"reduce_from_tensor_model_parallel_region",
"scatter_to_tensor_model_parallel_region",
"scatter_to_sequence_parallel_region",
# random.py
"checkpoint",
"get_cuda_rng_tracker",
......
# coding=utf-8
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021-22, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -20,7 +20,7 @@ from apex.transformer.parallel_state import get_tensor_model_parallel_rank
from apex.transformer.tensor_parallel.utils import split_tensor_along_last_dim
def _reduce(input_):
def _reduce(input_: torch.Tensor) -> torch.Tensor:
"""All-reduce the input tensor across model parallel group."""
# Bypass the function if we are using only 1 GPU.
......@@ -33,7 +33,7 @@ def _reduce(input_):
return input_
def _split(input_):
def _split_along_last_dim(input_: torch.Tensor) -> torch.Tensor:
"""Split the tensor along its last dimension and keep the
corresponding slice."""
......@@ -52,8 +52,24 @@ def _split(input_):
return output
def _gather(input_):
"""Gather tensors and concatinate along the last dimension."""
def _split_along_first_dim(input_: torch.Tensor) -> torch.Tensor:
"""Split the tensor along its first dimension and keep the corresponding slice."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU for tensor model parallel.
if world_size == 1:
return input_
# Split along first dimension.
dim_size = input_.size(0)
assert dim_size % world_size == 0
local_dim_size = dim_size // world_size
dim_offset = get_tensor_model_parallel_rank() * local_dim_size
output = input_[dim_offset:dim_offset + local_dim_size].contiguous()
return output
def _gather_along_last_dim(input_: torch.Tensor) -> torch.Tensor:
"""Gather tensors and concatenate along the last dimension."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
......@@ -76,9 +92,57 @@ def _gather(input_):
return output
def _gather_along_first_dim(input_: torch.Tensor) -> torch.Tensor:
"""Gather tensors and concatenate along the first dimension."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
shape = list(input_.shape)
shape[0] *= world_size
output = torch.empty(shape, dtype=input_.dtype, device=torch.cuda.current_device())
# The original implementation uses `_all_gather_base` as follows.
# Deliberately keeping the commented-out call for reference because
# I'd love to switch to this API once it becomes public/stable.
# torch.distributed._all_gather_base(output, input_.contiguous(), group=get_tensor_model_parallel_group())
torch.distributed.all_gather(
list(output.chunk(world_size)),
input_.contiguous(),
group=get_tensor_model_parallel_group(),
)
return output
def _reduce_scatter_along_first_dim(input_: torch.Tensor) -> torch.Tensor:
"""Reduce-scatter the input tensor across model parallel group."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
shape = list(input_.shape)
assert shape[0] % world_size == 0
shape[0] //= world_size
output = torch.empty(shape, dtype=input_.dtype, device=torch.cuda.current_device())
# The original implementation uses `_reduce_scatter_base` as follows.
# Deliberately keeping the commented-out call for reference because
# I'd love to switch to this API once it becomes public/stable.
# torch.distributed._reduce_scatter_base(output, input_.contiguous(), group=get_tensor_model_parallel_group())
torch.distributed.reduce_scatter(
output,
list(input_.contiguous().chunk(world_size)),
group=get_tensor_model_parallel_group(),
)
return output
class _CopyToModelParallelRegion(torch.autograd.Function):
"""Pass the input to the model parallel region."""
"""Pass the input to the tensor model parallel region."""
# FIXME(mkozuki): Definition of static symbolic methods don't look correct according to
# https://pytorch.org/docs/stable/onnx.html#static-symbolic-method
@staticmethod
def symbolic(graph, input_):
return input_
......@@ -93,8 +157,10 @@ class _CopyToModelParallelRegion(torch.autograd.Function):
class _ReduceFromModelParallelRegion(torch.autograd.Function):
"""All-reduce the input from the model parallel region."""
"""All-reduce the input from the tensor model parallel region."""
# FIXME(mkozuki): Definition of static symbolic methods don't look correct according to
# https://pytorch.org/docs/stable/onnx.html#static-symbolic-method
@staticmethod
def symbolic(graph, input_):
return _reduce(input_)
......@@ -111,33 +177,95 @@ class _ReduceFromModelParallelRegion(torch.autograd.Function):
class _ScatterToModelParallelRegion(torch.autograd.Function):
"""Split the input and keep only the corresponding chuck to the rank."""
# FIXME(mkozuki): Definition of static symbolic methods don't look correct according to
# https://pytorch.org/docs/stable/onnx.html#static-symbolic-method
@staticmethod
def symbolic(graph, input_):
return _split(input_)
return _split_along_last_dim(input_)
@staticmethod
def forward(ctx, input_):
return _split(input_)
return _split_along_last_dim(input_)
@staticmethod
def backward(ctx, grad_output):
return _gather(grad_output)
return _gather_along_last_dim(grad_output)
class _GatherFromModelParallelRegion(torch.autograd.Function):
"""Gather the input from model parallel region and concatinate."""
"""Gather the input from tensor model parallel region and concatenate."""
# FIXME(mkozuki): Definition of static symbolic methods don't look correct according to
# https://pytorch.org/docs/stable/onnx.html#static-symbolic-method
@staticmethod
def symbolic(graph, input_):
return _gather_along_last_dim(input_)
@staticmethod
def forward(ctx, input_):
return _gather_along_last_dim(input_)
@staticmethod
def backward(ctx, grad_output):
return _split_along_last_dim(grad_output)
class _ScatterToSequenceParallelRegion(torch.autograd.Function):
"""Split the input and keep only the corresponding chunk to the rank."""
# FIXME(mkozuki): Definition of static symbolic methods don't look correct according to
# https://pytorch.org/docs/stable/onnx.html#static-symbolic-method
@staticmethod
def symbolic(graph, input_):
return _gather(input_)
return _split_along_first_dim(input_)
@staticmethod
def forward(ctx, input_):
return _gather(input_)
return _split_along_first_dim(input_)
@staticmethod
def backward(ctx, grad_output):
return _split(grad_output)
return _gather_along_first_dim(grad_output)
class _GatherFromSequenceParallelRegion(torch.autograd.Function):
"""Gather the input from sequence parallel region and concatenate."""
# FIXME(mkozuki): Definition of static symbolic methods don't look correct according to
# https://pytorch.org/docs/stable/onnx.html#static-symbolic-method
@staticmethod
def symbolic(graph, input_, to_model_parallel: bool = True):
return _gather_along_first_dim(input_)
@staticmethod
def forward(ctx, input_, to_model_parallel: bool = True):
ctx.to_model_parallel = to_model_parallel
return _gather_along_first_dim(input_)
@staticmethod
def backward(ctx, grad_output):
if ctx.to_model_parallel:
return _reduce_scatter_along_first_dim(grad_output), None
else:
return _split_along_first_dim(grad_output), None
class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function):
"""Reduce scatter the input from the sequence parallel region and concatenate."""
# FIXME(mkozuki): Definition of static symbolic methods don't look correct according to
# https://pytorch.org/docs/stable/onnx.html#static-symbolic-method
@staticmethod
def symbolic(graph, input_):
return _reduce_scatter_along_first_dim(input_)
@staticmethod
def forward(ctx, input_):
return _reduce_scatter_along_first_dim(input_)
@staticmethod
def backward(ctx, grad_output):
return _gather_along_first_dim(grad_output)
# -----------------
......@@ -145,17 +273,40 @@ class _GatherFromModelParallelRegion(torch.autograd.Function):
# -----------------
def copy_to_tensor_model_parallel_region(input_):
def copy_to_tensor_model_parallel_region(input_: torch.Tensor) -> torch.Tensor:
return _CopyToModelParallelRegion.apply(input_)
def reduce_from_tensor_model_parallel_region(input_):
def reduce_from_tensor_model_parallel_region(input_: torch.Tensor) -> torch.Tensor:
return _ReduceFromModelParallelRegion.apply(input_)
def scatter_to_tensor_model_parallel_region(input_):
def scatter_to_tensor_model_parallel_region(input_: torch.Tensor) -> torch.Tensor:
return _ScatterToModelParallelRegion.apply(input_)
def gather_from_tensor_model_parallel_region(input_):
def gather_from_tensor_model_parallel_region(input_: torch.Tensor) -> torch.Tensor:
return _GatherFromModelParallelRegion.apply(input_)
def scatter_to_sequence_parallel_region(input_: torch.Tensor) -> torch.Tensor:
return _ScatterToSequenceParallelRegion.apply(input_)
def gather_from_sequence_parallel_region(input_: torch.Tensor, to_model_parallel: bool = True) -> torch.Tensor:
return _GatherFromSequenceParallelRegion.apply(input_, to_model_parallel)
def reduce_scatter_to_sequence_parallel_region(input_: torch.Tensor) -> torch.Tensor:
return _ReduceScatterToSequenceParallelRegion.apply(input_)
__all__ = [
"copy_to_tensor_model_parallel_region",
"reduce_from_tensor_model_parallel_region",
"scatter_to_tensor_model_parallel_region",
"gather_from_tensor_model_parallel_region",
"scatter_to_sequence_parallel_region",
"gather_from_sequence_parallel_region",
"reduce_scatter_to_sequence_parallel_region",
]
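To make the semantics of the new first-dimension (sequence) collectives concrete, a single-process toy that mimics their data movement with plain tensor ops; it simulates the ranks with a list of tensors and ignores the autograd wiring, so it is an illustration of the semantics only, not the actual collectives:

import torch


def scatter_seq_sketch(full: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # `_split_along_first_dim`: keep this rank's slice of the sequence dimension.
    chunk = full.size(0) // world_size
    return full[rank * chunk:(rank + 1) * chunk].clone()


def gather_seq_sketch(shards) -> torch.Tensor:
    # `_gather_along_first_dim`: concatenate the per-rank sequence shards.
    return torch.cat(shards, dim=0)


def reduce_scatter_seq_sketch(per_rank_full, rank: int) -> torch.Tensor:
    # `_reduce_scatter_along_first_dim`: sum the full tensors across ranks,
    # then keep only this rank's sequence slice.
    summed = torch.stack(per_rank_full).sum(dim=0)
    return scatter_seq_sketch(summed, rank, len(per_rank_full))


world_size = 2
full = torch.arange(4 * 3, dtype=torch.float32).reshape(4, 3)  # (seq, hidden)
shards = [scatter_seq_sketch(full, r, world_size) for r in range(world_size)]
assert torch.equal(gather_seq_sketch(shards), full)
assert torch.equal(reduce_scatter_seq_sketch([full, full], 0), 2 * full[:2])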
......@@ -13,6 +13,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO(mkozuki): Remove this file as Megatron-LM seems to have done so.
import torch
......
# coding=utf-8
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021-22, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -13,9 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO (mkozuki): Audit this file.
# Some functions here don't strongly relate to `random` in tensor_parallel;
# rather, they are mainly for gradient checkpointing (torch.utils.checkpoint).
# NOTE(mkozuki): This file is based on megatron-lm/mpu/random.py with some differences:
# - Not using "viewless" tensor:
# - _kernel_make_viewless_tensor
# - MakeViewlessTensor
# - make_viewless_tensor
# - assert_viewless_tensor
# - safely_set_viewless_tensor_data
# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch
......@@ -35,13 +39,12 @@ from apex.transformer.utils import gather_split_1d_tensor
# Default name for the model parallel rng tracker.
_MODEL_PARALLEL_RNG_TRACKER_NAME = "model-parallel-rng"
# TODO(mkozuki): Remove `_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER` as Megatron-LM doesn't seem to use it.
# Whether apply model parallelism to checkpointed hidden states.
_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = None
# TODO (mkozuki): Consider the possibility of removing `tensor_model_parallel_size`;
# `get_tensor_model_parallel_world_size()` might be an alternative.
# TODO(mkozuki): Remove `init_checkpointed_activations_memory_buffer` as Megatron-LM doesn't seem to use it.
def init_checkpointed_activations_memory_buffer(
micro_batch_size,
max_position_embeddings,
......@@ -77,6 +80,7 @@ def init_checkpointed_activations_memory_buffer(
)
# TODO(mkozuki): Remove `reset_checkpointed_activations_memory_buffer` as Megatron-LM doesn't seem to use it.
def reset_checkpointed_activations_memory_buffer():
"""Reset the memory used for checkpointing."""
if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
......@@ -86,7 +90,7 @@ def reset_checkpointed_activations_memory_buffer():
def _set_cuda_rng_state(new_state, device=-1):
"""Sets the random number generator state of the current GPU.
Argumentss:
Arguments:
new_state (torch.ByteTensor): The desired state
This function is adapted from PyTorch repo (torch.cuda.set_rng_state)
with a single change: the input state is not cloned. Cloning caused
......@@ -239,8 +243,9 @@ class CheckpointFunction(torch.autograd.Function):
"""
@staticmethod
def forward(ctx, run_function, *args):
def forward(ctx, run_function, distribute_saved_activations, *args):
ctx.run_function = run_function
ctx.distribute_saved_activations = distribute_saved_activations
# Copy the rng states.
ctx.fwd_cpu_rng_state = torch.get_rng_state()
......@@ -252,10 +257,8 @@ class CheckpointFunction(torch.autograd.Function):
# Divide hidden states across model parallel group and only keep
# the chunk corresponding to the current rank.
if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
ctx.input_0_shape = args[0].data.shape
args[0].data = split_tensor_into_1d_equal_chunks(args[0].data)
args[0].data = _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.add(args[0].data)
if ctx.distribute_saved_activations:
ctx.input_0_shape = args[0].shape
# Store everything.
ctx.save_for_backward(*args)
......@@ -269,9 +272,6 @@ class CheckpointFunction(torch.autograd.Function):
"please use .backward() if possible"
)
inputs = ctx.saved_tensors
if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
inputs[0].data = gather_split_1d_tensor(inputs[0].data)
inputs[0].data = inputs[0].data.view(ctx.input_0_shape)
# Store the current states.
bwd_cpu_rng_state = torch.get_rng_state()
......@@ -300,10 +300,12 @@ class CheckpointFunction(torch.autograd.Function):
inp.grad if isinstance(inp, torch.Tensor) else inp
for inp in detached_inputs
)
return (None,) + grads
return (None, None) + grads
def checkpoint(function, *args):
# NOTE(mkozuki): It doesn't look like `distribute_saved_activations` is used in apex.transformer
# but I added this change to reduce the superficial difference from Megatron-LM.
def checkpoint(function, distribute_saved_activations, *args):
"""Checkpoint a model or part of the model.
This has been directly copied from torch.utils.checkpoint."""
return CheckpointFunction.apply(function, *args)
return CheckpointFunction.apply(function, distribute_saved_activations, *args)
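Since the signature change means every call site now has to pass `distribute_saved_activations` explicitly, a minimal sketch of the new calling convention follows. `fake_checkpoint` is a stand-in with the same signature so the example runs without CUDA or an initialized RNG tracker; it is not the real implementation.

import torch


def fake_checkpoint(run_function, distribute_saved_activations, *args):
    # Stand-in mirroring the new `checkpoint(function, distribute_saved_activations, *args)`
    # signature above; the real version also saves and restores RNG state.
    assert isinstance(distribute_saved_activations, bool)
    return run_function(*args)


layer = torch.nn.Linear(8, 8)
hidden = torch.randn(2, 8)
# Old call sites passed `checkpoint(layer, hidden)`; the second positional
# argument is now `distribute_saved_activations`.
out = fake_checkpoint(layer, False, hidden)
assert out.shape == (2, 8)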
......@@ -12,12 +12,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Sequence
import torch
from apex.transformer.utils import divide
def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
def split_tensor_along_last_dim(
tensor: torch.Tensor,
num_partitions: int,
contiguous_split_chunks: bool = False,
) -> List[torch.Tensor]:
"""Split a tensor along its last dimension.
Arguments:
tensor: input tensor.
......@@ -44,14 +50,14 @@ class VocabUtility:
@staticmethod
def vocab_range_from_per_partition_vocab_size(
per_partition_vocab_size, rank, world_size
):
per_partition_vocab_size: int, rank, world_size: int
) -> Sequence[int]:
index_f = rank * per_partition_vocab_size
index_l = index_f + per_partition_vocab_size
return index_f, index_l
@staticmethod
def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int, world_size: int) -> Sequence[int]:
per_partition_vocab_size = divide(global_vocab_size, world_size)
return VocabUtility.vocab_range_from_per_partition_vocab_size(
per_partition_vocab_size, rank, world_size
......
......@@ -39,9 +39,13 @@ def parse_args(extra_args_provider=None, defaults={},
parser = _add_data_args(parser)
parser = _add_autoresume_args(parser)
parser = _add_biencoder_args(parser)
parser = _add_vit_args(parser)
parser = _add_vision_args(parser)
parser = _add_logging_args(parser)
# NOTE(mkozuki): This option is added to investigate the potential of `torch.autograd.graph.save_on_cpu()`.
# ref: https://pytorch.org/docs/stable/autograd.html#torch.autograd.graph.save_on_cpu.
parser.add_argument('--cpu-offload', action='store_true', default=False, help='Turns on CPU offloading')
# Custom arguments.
if extra_args_provider is not None:
parser = extra_args_provider(parser)
......@@ -65,6 +69,11 @@ def parse_args(extra_args_provider=None, defaults={},
args.pipeline_model_parallel_size = min(
args.pipeline_model_parallel_size,
(args.world_size // args.tensor_model_parallel_size))
args.transformer_pipeline_model_parallel_size = (
args.pipeline_model_parallel_size - 1
if args.standalone_embedding_stage else
args.pipeline_model_parallel_size
)
# Checks.
model_parallel_size = args.pipeline_model_parallel_size * \
args.tensor_model_parallel_size
......@@ -98,13 +107,18 @@ def parse_args(extra_args_provider=None, defaults={},
'longer valid, use --tensor-model-parallel-size instead'
del args.model_parallel_size
if args.checkpoint_activations:
args.activations_checkpoint_method = 'uniform'
args.recompute_granularity = 'full'
args.recompute_method = 'uniform'
if args.rank == 0:
print('--checkpoint-activations is no longer valid, '
'use --activation-checkpoint-method instead. '
'Defaulting to activation-checkpoint-method=uniform.')
'use --recompute-granularity and --recompute-method instead. '
'Defaulting to recompute-granularity=full and recompute-method=uniform.')
del args.checkpoint_activations
if args.recompute_activations:
args.recompute_granularity = 'selective'
del args.recompute_activations
# Set input defaults.
for key in defaults:
# For default to be valid, it should not be provided in the
......@@ -166,6 +180,14 @@ def parse_args(extra_args_provider=None, defaults={},
if args.accumulate_allreduce_grads_in_fp32:
assert args.DDP_impl == 'local'
assert args.use_contiguous_buffers_in_local_ddp
else:
if args.gradient_accumulation_fusion:
args.gradient_accumulation_fusion = False
if args.rank == 0:
print('Gradient accumulation fusion to linear layer weight '
'gradient computation is supported only with fp32 '
'gradient accumulation. Setting gradient_accumulation_fusion '
'to False', flush=True)
# For torch DDP, we do not use contiguous buffer
if args.DDP_impl == 'torch':
......@@ -244,17 +266,51 @@ def parse_args(extra_args_provider=None, defaults={},
if args.fp32_residual_connection:
assert args.fp16 or args.bf16, \
'residual connection in fp32 only supported when using fp16 or bf16.'
# Activation checkpointing.
if args.distribute_checkpointed_activations:
if args.weight_decay_incr_style == 'constant':
assert args.start_weight_decay is None
assert args.end_weight_decay is None
args.start_weight_decay = args.weight_decay
args.end_weight_decay = args.weight_decay
else:
assert args.start_weight_decay is not None
assert args.end_weight_decay is not None
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
# Persistent fused layer norm.
if TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 11):
args.no_persist_layer_norm = True
if args.rank == 0:
print('Persistent fused layer norm kernel is supported from '
'pytorch v1.11 (nvidia pytorch container paired with v1.11). '
'Defaulting to no_persist_layer_norm=True')
# Activation recomputing.
if args.distribute_saved_activations:
assert args.tensor_model_parallel_size > 1, 'can distribute ' \
'checkpointed activations only across tensor model ' \
'recomputed activations only across tensor model ' \
'parallel groups'
assert args.activations_checkpoint_method is not None, \
'for distribute-checkpointed-activations to work you '\
'need to use a activation-checkpoint method '
assert args.num_layers_per_virtual_pipeline_stage is None, \
'currently distributed checkpointed activations are only supported for ' \
'non-interleaved pipeline parallelism'
assert args.recompute_granularity == 'full', \
'distributed recompute activations is only '\
'applicable to full recompute granularity'
assert args.recompute_method is not None, \
'for distributed recompute activations to work you '\
'need to use a recompute method '
assert TORCH_MAJOR >= 1 and TORCH_MINOR >= 10, \
'distributed recompute activations are supported for pytorch ' \
'v1.10 and above (Nvidia Pytorch container >= 21.07). Current ' \
'pytorch version is v%s.%s.' % (TORCH_MAJOR, TORCH_MINOR)
if args.recompute_granularity == 'selective':
assert args.recompute_method is None, \
'recompute method is not yet supported for ' \
'selective recomputing granularity'
# disable async_tensor_model_parallel_allreduce when
# model parallel memory optimization is enabled
if args.sequence_parallel:
args.async_tensor_model_parallel_allreduce = False
_print_args(args)
return args
......@@ -279,6 +335,18 @@ def _check_arg_is_not_none(args, arg):
assert getattr(args, arg) is not None, '{} argument is None'.format(arg)
def _add_inference_args(parser):
group = parser.add_argument_group(title='inference')
group.add_argument('--inference-batch-times-seqlen-threshold',
type=int, default=512,
help='During inference, if batch-size times '
'sequence-length is smaller than this threshold '
'then we will not use pipelining, otherwise we will.')
return parser
def _add_network_size_args(parser):
group = parser.add_argument_group(title='network size')
......@@ -318,6 +386,8 @@ def _add_network_size_args(parser):
group.add_argument('--bert-no-binary-head', action='store_false',
help='Disable BERT binary head.',
dest='bert_binary_head')
group.add_argument('--num-experts', type=int, default=None,
help='Number of Experts in Switch Transformer (None means no Switch)')
return parser
......@@ -354,6 +424,9 @@ def _add_logging_args(parser):
group.add_argument('--log-memory-to-tensorboard',
action='store_true',
help='Enable memory logging to tensorboard.')
group.add_argument('--log-world-size-to-tensorboard',
action='store_true',
help='Enable world size logging to tensorboard.')
return parser
......@@ -367,6 +440,13 @@ def _add_regularization_args(parser):
help='Dropout probability for hidden state transformer.')
group.add_argument('--weight-decay', type=float, default=0.01,
help='Weight decay coefficient for L2 regularization.')
group.add_argument('--start-weight-decay', type=float,
help='Initial weight decay coefficient for L2 regularization.')
group.add_argument('--end-weight-decay', type=float,
help='End of run weight decay coefficient for L2 regularization.')
group.add_argument('--weight-decay-incr-style', type=str, default='constant',
choices=['constant', 'linear', 'cosine'],
help='Weight decay increment function.')
group.add_argument('--clip-grad', type=float, default=1.0,
help='Gradient clipping based on global L2 norm.')
group.add_argument('--adam-beta1', type=float, default=0.9,
......@@ -413,27 +493,40 @@ def _add_training_args(parser):
' (1024 - 16) / 8 = 126 intervals will increase'
'the batch size linearly to 1024. In each interval'
'we will use approximately 300000 / 126 = 2380 samples.')
group.add_argument('--checkpoint-activations', action='store_true',
help='Checkpoint activation to allow for training '
group.add_argument('--recompute-activations', action='store_true',
help='recompute activation to allow for training '
'with larger models, sequences, and batch sizes.')
group.add_argument('--distribute-checkpointed-activations',
group.add_argument('--recompute-granularity', type=str, default=None,
choices=['full', 'selective'],
help='Checkpoint activations to allow for training '
'with larger models, sequences, and batch sizes. '
'It is supported at two granularities 1) full: '
'whole transformer layer is recomputed, '
'2) selective: core attention part of the transformer '
'layer is recomputed.')
group.add_argument('--distribute-saved-activations',
action='store_true',
help='If set, distribute checkpointed activations '
help='If set, distribute recomputed activations '
'across model parallel group.')
group.add_argument('--activations-checkpoint-method', type=str, default=None,
group.add_argument('--recompute-method', type=str, default=None,
choices=['uniform', 'block'],
help='1) uniform: uniformly divide the total number of '
'Transformer layers and checkpoint the input activation of '
'each divided chunk, '
'2) checkpoint the input activations of only a set number of '
'Transformer layers and recompute the input activation of '
'each divided chunk at specified granularity, '
'2) recompute the input activations of only a set number of '
'individual Transformer layers per pipeline stage and do the '
'rest without any checkpointing'
'default) do not apply activations checkpoint to any layers')
group.add_argument('--activations-checkpoint-num-layers', type=int, default=1,
'rest without any recomputing at specified granularity'
'default) do not apply activations recompute to any layers')
group.add_argument('--recompute-num-layers', type=int, default=1,
help='1) uniform: the number of Transformer layers in each '
'uniformly divided checkpoint unit, '
'uniformly divided recompute unit, '
'2) block: the number of individual Transformer layers '
'to checkpoint within each pipeline stage.')
'to recompute within each pipeline stage.')
# deprecated
group.add_argument('--checkpoint-activations', action='store_true',
help='Checkpoint activation to allow for training '
'with larger models, sequences, and batch sizes.')
group.add_argument('--train-iters', type=int, default=None,
help='Total number of iterations to train over all '
'training runs. Note that either train-iters or '
......@@ -472,7 +565,20 @@ def _add_training_args(parser):
action='store_true',
help='Disable asynchronous execution of '
'tensor-model-parallel all-reduce with weight '
'gradient computation of a column-linear layer.')
'gradient computation of a column-linear layer.',
dest='async_tensor_model_parallel_allreduce')
group.add_argument('--no-persist-layer-norm', action='store_true',
help='Disable using persistent fused layer norm kernel. '
'This kernel supports only a set of hidden sizes. Please '
'check persist_ln_hidden_sizes if your hidden '
'size is supported.')
group.add_argument('--sequence-parallel', action='store_true',
help='Enable sequence parallel optimization.')
group.add_argument('--no-gradient-accumulation-fusion',
action='store_false',
help='Disable fusing gradient accumulation to weight '
'gradient computation of linear layers',
dest='gradient_accumulation_fusion')
return parser
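A self-contained argparse sketch (a mock, not the Megatron-style parser above) showing how the renamed recompute flags and the new `--sequence-parallel` / `--no-gradient-accumulation-fusion` options combine.
import argparse

# Standalone mock of the options added above; defaults mirror the diff.
parser = argparse.ArgumentParser()
parser.add_argument('--recompute-granularity', choices=['full', 'selective'], default=None)
parser.add_argument('--recompute-method', choices=['uniform', 'block'], default=None)
parser.add_argument('--recompute-num-layers', type=int, default=1)
parser.add_argument('--sequence-parallel', action='store_true')
parser.add_argument('--no-gradient-accumulation-fusion', action='store_false',
                    dest='gradient_accumulation_fusion')

# The deprecated --checkpoint-activations roughly maps onto the first two flags.
args = parser.parse_args(
    '--recompute-granularity full --recompute-method uniform --sequence-parallel'.split())
assert args.sequence_parallel and args.recompute_granularity == 'full'
assert args.gradient_accumulation_fusion  # still on unless explicitly disabled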
......@@ -640,13 +746,16 @@ def _add_distributed_args(parser):
group.add_argument('--use-cpu-initialization', action='store_true',
default=None, help='If set, affine parallel weights '
'initialization uses CPU' )
group.add_argument('--cpu-offload', action='store_true',
default=False, help='Turns on CPU offloading')
group.add_argument('--empty-unused-memory-level', default=0, type=int,
choices=[0, 1, 2],
help='Call torch.cuda.empty_cache() each iteration '
'(training and eval), to reduce fragmentation.'
'0=off, 1=moderate, 2=aggressive.')
group.add_argument('--standalone-embedding-stage', action='store_true',
default=False, help='If set, *input* embedding layer '
'is placed on its own pipeline stage, without any '
'transformer layers. (For T5, this flag currently only '
'affects the encoder embedding.)')
return parser
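Since `--cpu-offload` above simply toggles `torch.autograd.graph.save_on_cpu`, here is a minimal standalone illustration of that hook; the values are arbitrary and this is not code from the diff.
import torch
from torch.autograd.graph import save_on_cpu  # available in PyTorch >= 1.10

a = torch.randn(128, 128, requires_grad=True)
with save_on_cpu(pin_memory=False):
    # Tensors saved for backward inside this context are kept on CPU
    # (the interesting case is when the activations live on GPU).
    loss = (a @ a).sum()
loss.backward()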
......@@ -793,16 +902,70 @@ def _add_biencoder_args(parser):
return parser
def _add_vit_args(parser):
group = parser.add_argument_group(title="vit")
def _add_vision_args(parser):
group = parser.add_argument_group(title="vision")
# general vision arguments
group.add_argument('--num-classes', type=int, default=1000,
help='number of classes in vision classification task')
group.add_argument('--img-dim', type=int, default=224,
help='Image size for vision classification task')
group.add_argument('--img-h', type=int, default=224,
help='Image height for vision classification task')
group.add_argument('--img-w', type=int, default=224,
help='Image width for vision classification task')
group.add_argument('--num-channels', type=int, default=3,
help='Number of channels in input image data')
group.add_argument('--patch-dim', type=int, default=16,
help='patch dimension used in vit')
help='patch dimension')
group.add_argument('--classes-fraction', type=float, default=1.0,
help='training with fraction of classes.')
group.add_argument('--data-per-class-fraction', type=float, default=1.0,
help='training with fraction of data per class.')
group.add_argument('--no-data-sharding', action='store_false',
help='Disable data sharding.',
dest='data_sharding')
group.add_argument('--head-lr-mult', type=float, default=1.0,
help='learning rate multiplier for head during finetuning')
# pretraining type and backbone selection
group.add_argument('--vision-pretraining', action='store_true',
help='flag to indicate vision pretraining')
group.add_argument('--vision-pretraining-type', type=str, default='classify',
choices=['classify', 'inpaint', 'dino'],
help='pretraining objectives')
group.add_argument('--vision-backbone-type', type=str, default='vit',
choices=['vit', 'mit', 'swin'],
help='backbone types')
group.add_argument('--swin-backbone-type', type=str, default='tiny',
choices=['tiny', 'base', 'h3'],
help='swin backbone type')
# inpainting arguments
group.add_argument('--mask-type', type=str, default='random',
choices=['random', 'row'],
help='mask types')
group.add_argument('--mask-factor', type=float, default=1.0,
help='mask size scaling parameter')
# dino arguments
group.add_argument('--iter-per-epoch', type=int, default=1250,
help='iterations per epoch')
group.add_argument('--dino-local-img-size', type=int, default=96,
help='Image size of local crops for dino')
group.add_argument('--dino-local-crops-number', type=int, default=10,
help='Number of local crops')
group.add_argument('--dino-head-hidden-size', type=int, default=2048,
help='Hidden dimension size in dino head')
group.add_argument('--dino-bottleneck-size', type=int, default=256,
help='Bottleneck dimension in dino head')
group.add_argument('--dino-freeze-last-layer', type=float, default=1,
help='Freezing last layer weights')
group.add_argument('--dino-norm-last-layer', action='store_true',
help='Disable Norm in last layer.')
group.add_argument('--dino-warmup-teacher-temp', type=float, default=0.04,
help='warmup teacher temperature')
group.add_argument('--dino-teacher-temp', type=float, default=0.07,
help='teacher temperature')
group.add_argument('--dino-warmup-teacher-temp-epochs', type=int, default=30,
help='warmup teacher temperature epochs')
return parser
......@@ -12,19 +12,28 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
import datetime
import os
import random
from typing import Optional, Union, List
from typing import Optional, Union, List, Tuple, Callable, Dict
import numpy
import torch
import torch.nn as nn
from apex import transformer
from apex.transformer.tensor_parallel import(
ColumnParallelLinear,
RowParallelLinear,
scatter_to_sequence_parallel_region,
)
from apex.transformer.pipeline_parallel.utils import (
average_losses_across_data_parallel_group,
)
from apex.transformer.pipeline_parallel.schedules.common import (
Batch,
)
from apex.transformer.testing import global_vars
......@@ -45,7 +54,10 @@ class MyLayer(nn.Module):
class MyModel(nn.Module):
def __init__(
self, hidden_size: int, pre_process: bool = False, post_process: bool = False
self,
hidden_size: int, pre_process: bool = False, post_process: bool = False,
*,
add_encoder: bool = False, add_decoder: bool = False,
) -> None:
super().__init__()
self.pre_process = pre_process
......@@ -68,8 +80,105 @@ class MyModel(nn.Module):
return self.layer(self.input_tensor)
def model_provider_func(hidden_size, pre_process, post_process) -> MyModel:
return MyModel(hidden_size, pre_process, post_process)
class ToyParallelMLP(nn.Module):
def __init__(
self,
hidden_size: int, pre_process: bool = False, post_process: bool = False,
*,
sequence_parallel_enabled: bool = False,
# TODO(mkozuki): Support these two?
add_encoder: bool = False, add_decoder: bool = False,
) -> None:
super().__init__()
self.pre_process = pre_process
self.post_process = post_process
self.sequence_parallel_enabled = sequence_parallel_enabled
ffn_hidden_size = 4 * hidden_size
self.dense_h_to_4h = ColumnParallelLinear(
hidden_size,
ffn_hidden_size,
gather_output=False,
# init_method=init_method,
skip_bias_add=True,
# use_cpu_initialization=use_cpu_initialization,
bias=True,
sequence_parallel_enabled=sequence_parallel_enabled,
no_async_tensor_model_parallel_allreduce=True,
)
self.dense_4h_to_h = RowParallelLinear(
ffn_hidden_size,
hidden_size,
input_is_parallel=True,
# init_method=output_layer_init_method,
skip_bias_add=False,
# use_cpu_initialization=use_cpu_initialization,
bias=True,
sequence_parallel_enabled=sequence_parallel_enabled,
)
self.activation_func = torch.nn.GELU()
def set_input_tensor(
self,
input_tensor: Union[torch.Tensor, List[torch.Tensor]],
) -> None:
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
self.input_tensor = input_tensor[0]
def forward(
self,
x: Optional[torch.Tensor],
) -> torch.Tensor:
"""Forward of Simplified ParallelMLP.
Args:
x: :obj:`None` if pipeline rank != pipeline first rank. When :obj:`None`,
`self.input_tensor` is taken care of by `forward_step` defined in
apex/transformer/pipeline_parallel/schedules/common.py
"""
# [s, b, h]
if self.input_tensor is None:
input = x
else:
input = self.input_tensor
intermediate_parallel, bias_parallel = self.dense_h_to_4h(input)
if bias_parallel is not None:
intermediate_parallel += bias_parallel
intermediate_parallel = self.activation_func(intermediate_parallel)
# [s, b, h]
output, output_bias = self.dense_4h_to_h(intermediate_parallel)
return output
def model_provider_func(
hidden_size: int,
pre_process: bool,
post_process: bool,
*,
add_encoder: bool = False,
add_decoder: bool = False) -> MyModel:
return MyModel(hidden_size, pre_process, post_process, add_encoder=add_encoder, add_decoder=add_decoder)
def mlp_provider_func(
hidden_size: int,
pre_process: bool,
post_process: bool,
*,
add_encoder: bool = False,
add_decoder: bool = False,
sequence_parallel_enabled: bool = False,
) -> ToyParallelMLP:
return ToyParallelMLP(
hidden_size,
pre_process,
post_process,
add_encoder=add_encoder,
add_decoder=add_decoder,
sequence_parallel_enabled=sequence_parallel_enabled,
)
def process_batch(batch):
......@@ -94,6 +203,33 @@ def fwd_step_func(batch, model):
return y, loss_func
@dataclass(frozen=True)
class ToyParallelMLPFwdBwdStepFunc:
sequence_parallel_enabled: bool
def __call__(
self,
batch: Batch,
model: torch.nn.Module,
) -> Tuple[torch.Tensor, Callable[[torch.Tensor], Tuple[torch.Tensor, Dict[str, torch.Tensor]]]]:
x = batch[0] if isinstance(batch, list) else batch
if isinstance(x, torch.Tensor):
x = x.transpose(0, 1).contiguous()
if self.sequence_parallel_enabled:
x = scatter_to_sequence_parallel_region(x)
y = model(x)
# NOTE(mkozuki): This loss function is simplistic, but it is enough to
# sanity-check the ported pipeline functions.
def loss_func(x):
loss = torch.sum(x)
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {"avg": averaged_loss}
return y, loss_func
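A hedged sketch of how this functor is driven by a pipeline schedule, mirroring the `forward_backward_func` call sites in the tests below; the helper, shapes, and devices are illustrative assumptions.
import torch

def run_toy_step(forward_backward_func, model, sequence_parallel_enabled: bool = True):
    """Sketch: drive ToyParallelMLPFwdBwdStepFunc through a schedule obtained
    elsewhere (e.g. apex's get_forward_backward_func); shapes are made up."""
    fwd_bwd_step = ToyParallelMLPFwdBwdStepFunc(
        sequence_parallel_enabled=sequence_parallel_enabled)
    seq_len, micro_batch, hidden = 32, 2, 16
    # The functor transposes (batch, sequence, hidden) -> (sequence, batch, hidden).
    batch = [torch.randn(micro_batch, seq_len, hidden, device="cuda")]
    forward_backward_func(
        fwd_bwd_step,
        batch,
        model,
        forward_only=False,
        tensor_shape=(seq_len, micro_batch, hidden),
        sequence_parallel_enabled=sequence_parallel_enabled,
    )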
class IdentityLayer(torch.nn.Module):
def __init__(self, size, scale=1.0):
super(IdentityLayer, self).__init__()
......
import torch
import contextlib
from apex.normalization import FusedLayerNorm as LayerNorm
import torch
from apex.transformer import tensor_parallel
from apex.transformer.enums import AttnMaskType
from apex.transformer.enums import ModelType
from apex.transformer.layers import FusedLayerNorm as LayerNorm
from apex.transformer.testing.global_vars import get_args
from .standalone_gpt import get_language_model, get_linear_layer, init_method_normal, parallel_lm_logits, scaled_init_method_normal
from .standalone_gpt import MegatronModule
from apex.transformer.testing.standalone_transformer_lm import (
MegatronModule,
get_language_model,
get_linear_layer,
init_method_normal,
scaled_init_method_normal,
parallel_lm_logits,
)
def bert_extended_attention_mask(attention_mask):
# We create a 3D attention mask from a 2D tensor mask.
......@@ -23,6 +33,7 @@ def bert_extended_attention_mask(attention_mask):
return extended_attention_mask
def bert_position_ids(token_ids):
# Create position ids
seq_length = token_ids.size(1)
......@@ -32,6 +43,7 @@ def bert_position_ids(token_ids):
return position_ids
class BertLMHead(MegatronModule):
"""Masked LM head for Bert
......@@ -56,13 +68,18 @@ class BertLMHead(MegatronModule):
self.parallel_output = parallel_output
self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
setattr(self.dense.weight, 'sequence_parallel_enabled', args.sequence_parallel)
setattr(self.dense.bias, 'sequence_parallel_enabled', args.sequence_parallel)
self.layernorm = LayerNorm(
hidden_size, eps=layernorm_epsilon, sequence_parallel_enabled=args.sequence_parallel)
self.gelu = torch.nn.functional.gelu
if args.openai_gelu:
self.gelu = openai_gelu
elif args.onnx_safe:
self.gelu = erf_gelu
def forward(self, hidden_states, word_embeddings_weight):
hidden_states = self.dense(hidden_states)
hidden_states = self.gelu(hidden_states)
......@@ -73,6 +90,7 @@ class BertLMHead(MegatronModule):
bias=self.bias)
return output
def post_language_model_processing(lm_output, pooled_output,
lm_head, binary_head,
lm_labels,
......@@ -87,8 +105,12 @@ def post_language_model_processing(lm_output, pooled_output,
binary_logits = binary_head(pooled_output)
if lm_labels is None:
return lm_logits, binary_logits
# [s b h] => [b s h]
return lm_logits.transpose(0, 1).contiguous(), binary_logits
else:
# [b s] => [s b]
lm_labels = lm_labels.transpose(0, 1).contiguous()
# lm_logits: [s b h] lm_labels: [s b]
if fp16_lm_cross_entropy:
assert lm_logits.dtype == torch.half
lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits, lm_labels)
......@@ -116,7 +138,7 @@ class BertModel(MegatronModule):
self.parallel_output = parallel_output
self.pre_process = pre_process
self.post_process = post_process
self.cpu_offload = cpu_offload
init_method = init_method_normal(args.init_method_std)
scaled_init_method = scaled_init_method_normal(args.init_method_std,
args.num_layers)
......@@ -142,13 +164,17 @@ class BertModel(MegatronModule):
init_method)
self._binary_head_key = 'binary_head'
self.forward_context = contextlib.nullcontext
if cpu_offload:
self.forward_context = torch.autograd.graph.save_on_cpu
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
self.language_model.set_input_tensor(input_tensor)
def forward(self, bert_model_input, attention_mask,
tokentype_ids=None, lm_labels=None):
with torch.autograd.graph.save_on_cpu() if self.cpu_offload else contextlib.nullcontext():
with self.forward_context():
extended_attention_mask = bert_extended_attention_mask(attention_mask)
input_ids = bert_model_input
position_ids = bert_position_ids(input_ids)
......@@ -174,7 +200,7 @@ class BertModel(MegatronModule):
else:
return lm_output
# NOTE(mkozuki): This method is not maintained as apex only tests forward_backward with best effort.
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""For easy load when model is combined with other heads,
......@@ -197,6 +223,7 @@ class BertModel(MegatronModule):
= self.word_embeddings.state_dict(destination, prefix, keep_vars)
return state_dict_
# NOTE(mkozuki): This method is not maintained as apex only tests forward_backward with best effort.
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
......@@ -213,6 +240,16 @@ class BertModel(MegatronModule):
self.word_embeddings.load_state_dict(
state_dict[self._word_embeddings_for_head_key], strict=strict)
def bert_model_provider(pre_process=True, post_process=True, cpu_offload=False):
model = BertModel(num_tokentypes=0, add_binary_head=False, pre_process=pre_process, post_process=post_process, cpu_offload=cpu_offload)
args = get_args()
num_tokentypes = 2 if args.bert_binary_head else 0
model = BertModel(
num_tokentypes=num_tokentypes,
add_binary_head=args.bert_binary_head,
parallel_output=True,
pre_process=pre_process,
post_process=post_process,
cpu_offload=cpu_offload,
)
return model
......@@ -8,11 +8,13 @@ else:
HAS_TORCH_UCC = True
print("Use UCC as backend of Pipeline Parallel ProcessGroups")
from apex.transformer.enums import ModelType
from apex.transformer import tensor_parallel
from apex.transformer import parallel_state
from apex.transformer.log_util import set_logging_level
from apex.transformer.tensor_parallel import vocab_parallel_cross_entropy
from apex.transformer.pipeline_parallel.utils import setup_microbatch_calculator
from apex.transformer.pipeline_parallel.utils import unwrap_model
from apex.transformer.pipeline_parallel.utils import (
average_losses_across_data_parallel_group,
)
......@@ -148,8 +150,24 @@ def train(
batch = generate_fancy_data_labels(sequence_len, batch_size)
optim.zero_grad()
forward_backward_func(
fwd_step_func, batch, model, forward_only=False, tensor_shape=tensor_shape, async_comm=async_comm,
fwd_step_func,
batch,
model,
forward_only=False,
tensor_shape=tensor_shape,
async_comm=async_comm,
sequence_parallel_enabled=global_vars.get_args().sequence_parallel,
)
# All-reduce layernorm parameters across model parallel nodes
# when sequence parallelism is used
if parallel_state.get_tensor_model_parallel_world_size() > 1 and global_vars.get_args().sequence_parallel:
for model_module in model:
unwrapped_model = unwrap_model(model_module)
for param in unwrapped_model.parameters():
if getattr(param, 'sequence_parallel_enabled', False):
grad = param.grad
torch.distributed.all_reduce(grad, group=parallel_state.get_tensor_model_parallel_group())
optim.step()
......@@ -169,13 +187,15 @@ if __name__ == "__main__":
init = True
try:
for virtual_pipeline_model_parallel_size in (2, None):
async_comm = virtual_pipeline_model_parallel_size is None
args = global_vars.get_args()
async_comm = not args.sequence_parallel and virtual_pipeline_model_parallel_size is None
data_idx = 0
ONCE = False
if init:
init = False
args = global_vars.get_args()
args.padded_vocab_size = 128 # needed in standalone gpt
args.model_type = ModelType.encoder_or_decoder
batch_size = args.global_batch_size
micro_batch_size = args.micro_batch_size
setup_microbatch_calculator(
......@@ -201,7 +221,7 @@ if __name__ == "__main__":
tensor_parallel.random.model_parallel_cuda_manual_seed(0)
model = build_model(
bert_model_provider,
wrap_with_ddp=True,
wrap_with_ddp=parallel_state.get_data_parallel_world_size() > 1,
virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
cpu_offload=args.cpu_offload,
)
......
......@@ -12,8 +12,10 @@ else:
print("Use UCC as backend of Pipeline Parallel ProcessGroups")
from apex.transformer import parallel_state
from apex.transformer.enums import ModelType
from apex.transformer.tensor_parallel import model_parallel_cuda_manual_seed
from apex.transformer.pipeline_parallel.utils import setup_microbatch_calculator
from apex.transformer.pipeline_parallel.utils import unwrap_model
from apex.transformer.pipeline_parallel.utils import (
average_losses_across_data_parallel_group,
)
......@@ -132,10 +134,25 @@ def train(model, optim, pipeline_model_parallel_size, async_comm):
print("finished making batch...")
optim.zero_grad()
fwd_bwd_func(
fwd_step_func, batch, model, forward_only=False, tensor_shape=tensor_shape, async_comm=async_comm
fwd_step_func,
batch,
model,
forward_only=False,
tensor_shape=tensor_shape,
async_comm=async_comm,
sequence_parallel_enabled=args.sequence_parallel,
)
if torch.distributed.get_rank() == 0:
print("finished forward step")
# All-reduce layernorm parameters across model parallel nodes
# when sequence parallelism is used
if parallel_state.get_tensor_model_parallel_world_size() > 1 and global_vars.get_args().sequence_parallel:
for model_module in model:
unwrapped_model = unwrap_model(model_module)
for param in unwrapped_model.parameters():
if getattr(param, 'sequence_parallel_enabled', False):
grad = param.grad
torch.distributed.all_reduce(grad, group=parallel_state.get_tensor_model_parallel_group())
optim.step()
if torch.distributed.get_rank() == 0:
print("finished iter", i)
......@@ -145,16 +162,17 @@ def train(model, optim, pipeline_model_parallel_size, async_comm):
if __name__ == "__main__":
init = True
for async_comm in (False, True):
global_vars.set_global_variables()
for async_comm in (False,) if global_vars.get_args().sequence_parallel else (False, True):
global fancy_data
global effective_length
if init:
init = False
global_vars.set_global_variables()
fancy_data = download_fancy_data()
args = global_vars.get_args()
args.model_type = ModelType.encoder_or_decoder
effective_length = fancy_data.size(0) // args.seq_length
effective_length = fancy_data.size(0) - args.seq_length
......@@ -189,7 +207,7 @@ if __name__ == "__main__":
model_parallel_cuda_manual_seed(0)
model = build_model(
gpt_model_provider,
wrap_with_ddp=True,
wrap_with_ddp=parallel_state.get_data_parallel_world_size() > 1,
virtual_pipeline_model_parallel_size=None,
cpu_offload=args.cpu_offload,
)
......
......@@ -3,13 +3,13 @@ import logging
import torch
from torch.testing._internal import common_utils
logging.getLogger("torch").setLevel(logging.WARNING)
from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import mappings
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase
logging.getLogger("torch").setLevel(logging.WARNING)
logging.getLogger("apex").setLevel(logging.WARNING)
......@@ -49,7 +49,7 @@ class MappingTestBase:
for rank in range(tensor_model_paralell_world_size)
]
x = torch.cat(tensors, 1)
out = mappings._split(x)
out = mappings._split_along_last_dim(x)
self.assertTrue(
torch.equal(
out, tensors[parallel_state.get_tensor_model_parallel_rank()]
......@@ -68,7 +68,7 @@ class MappingTestBase:
tensor_model_parallel_size_=tensor_model_paralell_world_size
)
device = f"cuda:{self.rank}"
gathered = mappings._gather(
gathered = mappings._gather_along_last_dim(
torch.tensor(
[parallel_state.get_tensor_model_parallel_rank()], device=device
)
......
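The renames distinguish the existing last-dimension collectives (tensor parallel) from the first-dimension variants used for sequence parallelism. Below is a single-process sketch of the two split directions using plain torch ops rather than the private mappings API.
import torch

x = torch.arange(24).reshape(4, 6)  # (sequence, hidden)
tp_world_size = 2

# _split_along_last_dim-style: each tensor-parallel rank keeps a hidden-dim shard.
last_dim_shards = torch.chunk(x, tp_world_size, dim=-1)   # 2 x (4, 3)

# The sequence-parallel counterpart splits along the first (sequence) dimension.
first_dim_shards = torch.chunk(x, tp_world_size, dim=0)   # 2 x (2, 6)

assert last_dim_shards[0].shape == (4, 3) and first_dim_shards[0].shape == (2, 6)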