Commit 3aca1415 authored by liangjing

Merge branch 'megatron-lm_dtk24.04' into 'main'

Megatron lm dtk24.04

See merge request !1
parents 0024a5c6 1005e9d3
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
MAJOR = 0
MINOR = 3
PATCH = 0
PRE_RELEASE = ''
# Use the following formatting: (major, minor, patch, pre-release)
VERSION = (MAJOR, MINOR, PATCH, PRE_RELEASE)
__shortversion__ = '.'.join(map(str, VERSION[:3]))
__version__ = '.'.join(map(str, VERSION[:3])) + ''.join(VERSION[3:])
__package_name__ = 'megatron_core'
__contact_names__ = 'NVIDIA'
__contact_emails__ = 'nemo-toolkit@nvidia.com' # use NeMo Email
__homepage__ = (
'https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/' # use NeMo homepage
)
__repository_url__ = 'https://github.com/NVIDIA/Megatron-LM/megatron/core'
__download_url__ = 'https://github.com/NVIDIA/Megatron-LM/releases'
__description__ = (
'Megatron Core - a library for efficient and scalable training of transformer based models'
)
__license__ = 'BSD-3'
__keywords__ = (
'deep learning, machine learning, gpu, NLP, NLU, language, transformer, nvidia, pytorch, torch'
)
......@@ -2,9 +2,11 @@
"""Model and data parallel groups."""
import os
from typing import Optional
import torch
from .utils import GlobalMemoryBuffer
# Intra-layer model parallel group that the current rank belongs to.
......@@ -19,6 +21,9 @@ _EMBEDDING_GROUP = None
_POSITION_EMBEDDING_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None
_DATA_PARALLEL_GROUP_GLOO = None
# FP8 amax reduction group.
_AMAX_REDUCTION_GROUP = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
......@@ -53,9 +58,10 @@ def initialize_model_parallel(
pipeline_model_parallel_size: int = 1,
virtual_pipeline_model_parallel_size: Optional[int] = None,
pipeline_model_parallel_split_rank: Optional[int] = None,
use_fp8: bool = False,
use_sharp: bool = False,
) -> None:
"""
Initialize model data parallel groups.
"""Initialize model data parallel groups.
Arguments:
tensor_model_parallel_size (int, default = 1):
......@@ -93,6 +99,17 @@ def initialize_model_parallel(
pipeline_model_parallel_split_rank is 3, then ranks 0-2
will be the encoder and ranks 3-7 will be the decoder.
use_fp8 (bool, default = False):
Construct GPU groups needed for FP8 training, namely for
amax reduction across the product of the data-parallel and
tensor-parallel groups.
use_sharp (bool, default = False):
Set the use of SHARP for the collective communications of
data-parallel process groups. When `True`, a barrier is run
within each data-parallel process group, which designates
those groups as the SHARP application targets.
Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
the model pipeline. The present function will
......@@ -108,6 +125,7 @@ def initialize_model_parallel(
are on the same DGX box. For example if we are using 2 DGX-1 boxes
with a total of 16 GPUs, rank 0 to 7 belong to the first box and
ranks 8 to 15 belong to the second box.
"""
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
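# Illustrative worked example (mirrors the 16-GPU example in the docstring above):
# with world_size = 16, tensor_model_parallel_size = 2 and
# pipeline_model_parallel_size = 4, data_parallel_size = 16 // (2 * 4) = 2 and the
# group-construction loops in this function produce
#   8 tensor model-parallel groups:  [0, 1], [2, 3], ..., [14, 15]
#   4 pipeline model-parallel groups: [0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]
#   8 data-parallel groups:          [0, 2], [1, 3], [4, 6], [5, 7], [8, 10], [9, 11], [12, 14], [13, 15]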
......@@ -119,17 +137,19 @@ def initialize_model_parallel(
f"({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size})"
)
data_parallel_size: int = world_size // (tensor_model_parallel_size *
pipeline_model_parallel_size)
data_parallel_size: int = world_size // (
tensor_model_parallel_size * pipeline_model_parallel_size
)
num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size
num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
num_data_parallel_groups: int = world_size // data_parallel_size
if virtual_pipeline_model_parallel_size is not None:
if not pipeline_model_parallel_size > 2:
raise RuntimeError("pipeline-model-parallel size should be greater than 2 with "
"interleaved schedule")
raise RuntimeError(
"pipeline-model-parallel size should be greater than 2 with interleaved schedule"
)
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
......@@ -143,6 +163,7 @@ def initialize_model_parallel(
# Build the data-parallel groups.
global _DATA_PARALLEL_GROUP
global _DATA_PARALLEL_GROUP_GLOO
global _DATA_PARALLEL_GLOBAL_RANKS
assert _DATA_PARALLEL_GROUP is None, 'data parallel group is already initialized'
all_data_parallel_group_ranks = []
......@@ -153,27 +174,50 @@ def initialize_model_parallel(
ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)
all_data_parallel_group_ranks.append(list(ranks))
group = torch.distributed.new_group(ranks)
group_gloo = torch.distributed.new_group(ranks, backend="gloo")
if rank in ranks:
_DATA_PARALLEL_GROUP = group
_DATA_PARALLEL_GROUP_GLOO = group_gloo
_DATA_PARALLEL_GLOBAL_RANKS = ranks
# Apply SHARP to DP process groups
if use_sharp:
if rank == 0:
print(
"The number of process groups to use SHARP with depends on the type "
"of the network switch. Nvidia QM1 switch supports SAHRP up to 8 "
"process groups and QM2 supports up to 256 process groups. We apply "
"SHARP to the communications of the data-parallel domain. If the "
"number of data-parallel process groups is larger than the max "
"process groups that the network switch supports, the communication "
"will fall back to non-SHARP operators. To enable SHARP, "
"`#SBATCH_NETWORK=sharp` should be set in the sbatch script."
)
torch.distributed.barrier(
group=get_data_parallel_group(), device_ids=[torch.cuda.current_device()]
)
# Set `NCCL_SHARP_DISABLE=1` to restrict SHARP application to DP process groups
os.environ["NCCL_SHARP_DISABLE"] = "1"
# Build the model-parallel groups.
global _MODEL_PARALLEL_GROUP
assert _MODEL_PARALLEL_GROUP is None, 'model parallel group is already initialized'
for i in range(data_parallel_size):
ranks = [data_parallel_group_ranks[i]
for data_parallel_group_ranks in all_data_parallel_group_ranks]
ranks = [
data_parallel_group_ranks[i]
for data_parallel_group_ranks in all_data_parallel_group_ranks
]
group = torch.distributed.new_group(ranks)
if rank in ranks:
_MODEL_PARALLEL_GROUP = group
# Build the tensor model-parallel groups.
global _TENSOR_MODEL_PARALLEL_GROUP
assert _TENSOR_MODEL_PARALLEL_GROUP is None, \
'tensor model parallel group is already initialized'
assert (
_TENSOR_MODEL_PARALLEL_GROUP is None
), 'tensor model parallel group is already initialized'
for i in range(num_tensor_model_parallel_groups):
ranks = range(i * tensor_model_parallel_size,
(i + 1) * tensor_model_parallel_size)
ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)
group = torch.distributed.new_group(ranks)
if rank in ranks:
_TENSOR_MODEL_PARALLEL_GROUP = group
......@@ -182,15 +226,15 @@ def initialize_model_parallel(
# (first and last rank in each pipeline model-parallel group).
global _PIPELINE_MODEL_PARALLEL_GROUP
global _PIPELINE_GLOBAL_RANKS
assert _PIPELINE_MODEL_PARALLEL_GROUP is None, \
'pipeline model parallel group is already initialized'
assert (
_PIPELINE_MODEL_PARALLEL_GROUP is None
), 'pipeline model parallel group is already initialized'
global _EMBEDDING_GROUP
global _EMBEDDING_GLOBAL_RANKS
assert _EMBEDDING_GROUP is None, 'embedding group is already initialized'
global _POSITION_EMBEDDING_GROUP
global _POSITION_EMBEDDING_GLOBAL_RANKS
assert _POSITION_EMBEDDING_GROUP is None, \
'position embedding group is already initialized'
assert _POSITION_EMBEDDING_GROUP is None, 'position embedding group is already initialized'
for i in range(num_pipeline_model_parallel_groups):
ranks = range(i, world_size, num_pipeline_model_parallel_groups)
group = torch.distributed.new_group(ranks)
......@@ -204,12 +248,13 @@ def initialize_model_parallel(
position_embedding_ranks = [ranks[0]]
if pipeline_model_parallel_split_rank is not None:
if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks:
embedding_ranks = [ranks[0],
ranks[pipeline_model_parallel_split_rank],
ranks[-1]]
embedding_ranks = [
ranks[0],
ranks[pipeline_model_parallel_split_rank],
ranks[-1],
]
if ranks[pipeline_model_parallel_split_rank] not in position_embedding_ranks:
position_embedding_ranks = [ranks[0],
ranks[pipeline_model_parallel_split_rank]]
position_embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank]]
else:
embedding_ranks = ranks
position_embedding_ranks = ranks
......@@ -226,6 +271,20 @@ def initialize_model_parallel(
if rank in ranks:
_POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks
# Build the FP8 groups.
global _AMAX_REDUCTION_GROUP
assert _AMAX_REDUCTION_GROUP is None, 'FP8 amax reduction group is already initialized'
if use_fp8:
amax_group_size: int = tensor_model_parallel_size * data_parallel_size
num_amax_groups: int = world_size // amax_group_size
for i in range(num_amax_groups):
start_rank = i * amax_group_size
end_rank = (i + 1) * amax_group_size
ranks = range(start_rank, end_rank)
group = torch.distributed.new_group(ranks)
if rank in ranks:
_AMAX_REDUCTION_GROUP = group
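# Illustrative worked example: for the 16-GPU case above (TP = 2, PP = 4, DP = 2),
# amax_group_size = 2 * 2 = 4 and num_amax_groups = 16 // 4 = 4, so the FP8 amax
# reduction groups are [0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15].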
# Initialize global memory buffer
# This isn't really "parallel state" but there isn't another good place to
# put this. If we end up with a more generic initialization of megatron-core
......@@ -240,55 +299,68 @@ def is_unitialized():
def model_parallel_is_initialized():
"""Check if model and data parallel groups are initialized."""
if _TENSOR_MODEL_PARALLEL_GROUP is None or \
_PIPELINE_MODEL_PARALLEL_GROUP is None or \
_DATA_PARALLEL_GROUP is None:
if (
_TENSOR_MODEL_PARALLEL_GROUP is None
or _PIPELINE_MODEL_PARALLEL_GROUP is None
or _DATA_PARALLEL_GROUP is None
):
return False
return True
def get_model_parallel_group():
"""Get the model parallel group the caller rank belongs to."""
assert _MODEL_PARALLEL_GROUP is not None, \
'model parallel group is not initialized'
assert _MODEL_PARALLEL_GROUP is not None, 'model parallel group is not initialized'
return _MODEL_PARALLEL_GROUP
def get_tensor_model_parallel_group():
def get_tensor_model_parallel_group(check_initialized=True):
"""Get the tensor model parallel group the caller rank belongs to."""
assert _TENSOR_MODEL_PARALLEL_GROUP is not None, \
'intra_layer_model parallel group is not initialized'
if check_initialized:
assert (
_TENSOR_MODEL_PARALLEL_GROUP is not None
), 'tensor model parallel group is not initialized'
return _TENSOR_MODEL_PARALLEL_GROUP
def get_pipeline_model_parallel_group():
"""Get the pipeline model parallel group the caller rank belongs to."""
assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, \
'pipeline_model parallel group is not initialized'
assert (
_PIPELINE_MODEL_PARALLEL_GROUP is not None
), 'pipeline_model parallel group is not initialized'
return _PIPELINE_MODEL_PARALLEL_GROUP
def get_data_parallel_group():
"""Get the data parallel group the caller rank belongs to."""
assert _DATA_PARALLEL_GROUP is not None, \
'data parallel group is not initialized'
assert _DATA_PARALLEL_GROUP is not None, 'data parallel group is not initialized'
return _DATA_PARALLEL_GROUP
def get_data_parallel_group_gloo():
"""Get the data parallel group-gloo the caller rank belongs to."""
assert _DATA_PARALLEL_GROUP_GLOO is not None, 'data parallel group-gloo is not initialized'
return _DATA_PARALLEL_GROUP_GLOO
def get_embedding_group():
"""Get the embedding group the caller rank belongs to."""
assert _EMBEDDING_GROUP is not None, \
'embedding group is not initialized'
assert _EMBEDDING_GROUP is not None, 'embedding group is not initialized'
return _EMBEDDING_GROUP
def get_position_embedding_group():
"""Get the position embedding group the caller rank belongs to."""
assert _POSITION_EMBEDDING_GROUP is not None, \
'position embedding group is not initialized'
assert _POSITION_EMBEDDING_GROUP is not None, 'position embedding group is not initialized'
return _POSITION_EMBEDDING_GROUP
def get_amax_reduction_group():
"""Get the FP8 amax reduction group the caller rank belongs to."""
assert _AMAX_REDUCTION_GROUP is not None, 'FP8 amax reduction group is not initialized'
return _AMAX_REDUCTION_GROUP
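# Illustrative usage sketch (hypothetical helper, not part of this module): a caller
# that has already run torch.distributed.init_process_group() would typically set up
# and query these groups as follows; the sizes are example values, not defaults.
def _example_parallel_state_setup():
    """Hypothetical sketch of typical parallel-state initialization."""
    assert torch.distributed.is_initialized()
    initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4)
    tp_group = get_tensor_model_parallel_group()
    dp_size = get_data_parallel_world_size()
    pp_first_rank = get_pipeline_model_parallel_first_rank()
    return tp_group, dp_size, pp_first_rank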
def set_tensor_model_parallel_world_size(world_size):
"""Set the tensor model parallel size"""
global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
......@@ -301,6 +373,12 @@ def set_pipeline_model_parallel_world_size(world_size):
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size
def set_virtual_pipeline_model_parallel_world_size(world_size):
"""Set the pipeline model parallel size"""
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size
def get_tensor_model_parallel_world_size():
"""Return world size for the tensor model parallel group."""
global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
......@@ -360,8 +438,10 @@ def get_pipeline_model_parallel_split_rank():
def is_pipeline_first_stage(ignore_virtual=False):
"""Return True if in the first pipeline model-parallel stage, False otherwise."""
if not ignore_virtual:
if get_virtual_pipeline_model_parallel_world_size() is not None and \
get_virtual_pipeline_model_parallel_rank() != 0:
if (
get_virtual_pipeline_model_parallel_world_size() is not None
and get_virtual_pipeline_model_parallel_rank() != 0
):
return False
return get_pipeline_model_parallel_rank() == 0
......@@ -369,14 +449,14 @@ def is_pipeline_first_stage(ignore_virtual=False):
def is_pipeline_last_stage(ignore_virtual=False):
"""Return True if in the last pipeline model-parallel stage, False otherwise."""
if not ignore_virtual:
virtual_pipeline_model_parallel_world_size = \
virtual_pipeline_model_parallel_world_size = (
get_virtual_pipeline_model_parallel_world_size()
if virtual_pipeline_model_parallel_world_size is not None and \
get_virtual_pipeline_model_parallel_rank() != (
virtual_pipeline_model_parallel_world_size - 1):
)
if virtual_pipeline_model_parallel_world_size is not None and get_virtual_pipeline_model_parallel_rank() != (
virtual_pipeline_model_parallel_world_size - 1
):
return False
return get_pipeline_model_parallel_rank() == (
get_pipeline_model_parallel_world_size() - 1)
return get_pipeline_model_parallel_rank() == (get_pipeline_model_parallel_world_size() - 1)
def is_rank_in_embedding_group(ignore_virtual=False):
......@@ -437,8 +517,7 @@ def is_pipeline_stage_at_split():
stage executes encoder block for a model with both encoder and
decoder."""
rank = get_pipeline_model_parallel_rank()
return is_pipeline_stage_before_split(rank) and \
is_pipeline_stage_after_split(rank+1)
return is_pipeline_stage_before_split(rank) and is_pipeline_stage_after_split(rank + 1)
def get_virtual_pipeline_model_parallel_rank():
......@@ -459,12 +538,6 @@ def get_virtual_pipeline_model_parallel_world_size():
return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
def set_virtual_pipeline_model_parallel_world_size(world_size):
"""Set the virtual pipeline-parallel world size"""
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size
def get_tensor_model_parallel_src_rank():
"""Calculate the global rank corresponding to the first local rank
in the tensor model parallel group."""
......@@ -476,31 +549,28 @@ def get_tensor_model_parallel_src_rank():
def get_data_parallel_src_rank():
"""Calculate the global rank corresponding to the first local rank
in the data parallel group."""
assert _DATA_PARALLEL_GLOBAL_RANKS is not None, \
"Data parallel group is not initialized"
assert _DATA_PARALLEL_GLOBAL_RANKS is not None, "Data parallel group is not initialized"
return _DATA_PARALLEL_GLOBAL_RANKS[0]
def get_pipeline_model_parallel_first_rank():
"""Return the global rank of the first process in the pipeline for the
current tensor parallel group"""
assert _PIPELINE_GLOBAL_RANKS is not None, \
"Pipeline parallel group is not initialized"
assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
return _PIPELINE_GLOBAL_RANKS[0]
def get_pipeline_model_parallel_last_rank():
"""Return the global rank of the last process in the pipeline for the
current tensor parallel group"""
assert _PIPELINE_GLOBAL_RANKS is not None, \
"Pipeline parallel group is not initialized"
assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
last_rank_local = get_pipeline_model_parallel_world_size() - 1
return _PIPELINE_GLOBAL_RANKS[last_rank_local]
def get_pipeline_model_parallel_next_rank():
"""Return the global rank that follows the caller in the pipeline"""
assert _PIPELINE_GLOBAL_RANKS is not None, \
"Pipeline parallel group is not initialized"
assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
rank_in_pipeline = get_pipeline_model_parallel_rank()
world_size = get_pipeline_model_parallel_world_size()
return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]
......@@ -508,8 +578,7 @@ def get_pipeline_model_parallel_next_rank():
def get_pipeline_model_parallel_prev_rank():
"""Return the global rank that preceeds the caller in the pipeline"""
assert _PIPELINE_GLOBAL_RANKS is not None, \
"Pipeline parallel group is not initialized"
assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
rank_in_pipeline = get_pipeline_model_parallel_rank()
world_size = get_pipeline_model_parallel_world_size()
return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]
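# Illustrative example: if the caller's pipeline group is [0, 4, 8, 12] (the 16-GPU
# case above) and the caller is the last stage (rank_in_pipeline == 3), the modular
# arithmetic above wraps around, so the "next" rank is 0 and the "previous" rank is 8.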
......@@ -517,12 +586,19 @@ def get_pipeline_model_parallel_prev_rank():
def get_data_parallel_world_size():
"""Return world size for the data parallel group."""
return torch.distributed.get_world_size(group=get_data_parallel_group())
if torch.distributed.is_available() and torch.distributed.is_initialized():
return torch.distributed.get_world_size(group=get_data_parallel_group())
else:
return 0
def get_data_parallel_rank():
"""Return my rank for the data parallel group."""
return torch.distributed.get_rank(group=get_data_parallel_group())
if torch.distributed.is_available() and torch.distributed.is_initialized():
return torch.distributed.get_rank(group=get_data_parallel_group())
else:
return 0
def _set_global_memory_buffer():
"""Initialize global buffer"""
......@@ -530,12 +606,19 @@ def _set_global_memory_buffer():
assert _GLOBAL_MEMORY_BUFFER is None, 'global memory buffer is already initialized'
_GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer()
def get_global_memory_buffer():
"""Return the global GlobalMemoryBuffer object"""
assert _GLOBAL_MEMORY_BUFFER is not None, 'global memory buffer is not initialized'
return _GLOBAL_MEMORY_BUFFER
def destroy_global_memory_buffer():
"""Sets the global memory buffer to None"""
global _GLOBAL_MEMORY_BUFFER
_GLOBAL_MEMORY_BUFFER = None
def destroy_model_parallel():
"""Set the groups to none."""
global _MODEL_PARALLEL_GROUP
......@@ -550,6 +633,8 @@ def destroy_model_parallel():
_EMBEDDING_GROUP = None
global _POSITION_EMBEDDING_GROUP
_POSITION_EMBEDDING_GROUP = None
global _AMAX_REDUCTION_GROUP
_AMAX_REDUCTION_GROUP = None
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
......
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
from functools import reduce
import operator
from typing import Optional, List, Union, Callable, Tuple
from typing import Callable, List, Optional, Tuple, Union
import torch
from megatron import core
from megatron.core import ModelParallelConfig
from megatron.core.parallel_state import (
get_pipeline_model_parallel_group,
get_pipeline_model_parallel_next_rank,
get_pipeline_model_parallel_prev_rank,
get_pipeline_model_parallel_rank,
)
# Types
Shape = Union[List[int], torch.Size]
def _communicate_shapes(tensor_send_next, tensor_send_prev,
recv_prev, recv_next,
use_ring_exchange_p2p):
def _communicate_shapes(tensor_send_next, tensor_send_prev, recv_prev, recv_next, config):
"""Communicate tensor shapes between stages. Used to communicate
tensor shapes before the actual tensor communication happens.
This is required when the sequence lengths across micro batches
......@@ -42,49 +43,59 @@ def _communicate_shapes(tensor_send_next, tensor_send_prev,
send_prev_shape_tensor = None
send_next_shape_tensor = None
if recv_prev:
recv_prev_shape_tensor = torch.empty((3),
device=torch.cuda.current_device(),
dtype=torch.int64)
recv_prev_shape_tensor = torch.empty(
(3), device=torch.cuda.current_device(), dtype=torch.int64
)
if recv_next:
recv_next_shape_tensor = torch.empty((3),
device=torch.cuda.current_device(),
dtype=torch.int64)
recv_next_shape_tensor = torch.empty(
(3), device=torch.cuda.current_device(), dtype=torch.int64
)
if tensor_send_prev is not None:
send_prev_shape_tensor = torch.tensor(tensor_send_prev.size(),
device=torch.cuda.current_device(),
dtype=torch.int64)
send_prev_shape_tensor = torch.tensor(
tensor_send_prev.size(), device=torch.cuda.current_device(), dtype=torch.int64
)
if tensor_send_next is not None:
send_next_shape_tensor = torch.tensor(tensor_send_next.size(),
device=torch.cuda.current_device(),
dtype=torch.int64)
if use_ring_exchange_p2p:
torch.distributed.ring_exchange(tensor_send_prev=send_prev_shape_tensor,
tensor_recv_prev=recv_prev_shape_tensor,
tensor_send_next=send_next_shape_tensor,
tensor_recv_next=recv_next_shape_tensor,
group=mpu.get_pipeline_model_parallel_group())
send_next_shape_tensor = torch.tensor(
tensor_send_next.size(), device=torch.cuda.current_device(), dtype=torch.int64
)
if config.use_ring_exchange_p2p:
torch.distributed.ring_exchange(
tensor_send_prev=send_prev_shape_tensor,
tensor_recv_prev=recv_prev_shape_tensor,
tensor_send_next=send_next_shape_tensor,
tensor_recv_next=recv_next_shape_tensor,
group=get_pipeline_model_parallel_group(),
)
else:
ops = []
if send_prev_shape_tensor is not None:
send_prev_op = torch.distributed.P2POp(
torch.distributed.isend, send_prev_shape_tensor,
mpu.get_pipeline_model_parallel_prev_rank())
torch.distributed.isend,
send_prev_shape_tensor,
get_pipeline_model_parallel_prev_rank(),
)
ops.append(send_prev_op)
if recv_prev_shape_tensor is not None:
recv_prev_op = torch.distributed.P2POp(
torch.distributed.irecv, recv_prev_shape_tensor,
mpu.get_pipeline_model_parallel_prev_rank())
torch.distributed.irecv,
recv_prev_shape_tensor,
get_pipeline_model_parallel_prev_rank(),
)
ops.append(recv_prev_op)
if send_next_shape_tensor is not None:
send_next_op = torch.distributed.P2POp(
torch.distributed.isend, send_next_shape_tensor,
mpu.get_pipeline_model_parallel_next_rank())
torch.distributed.isend,
send_next_shape_tensor,
get_pipeline_model_parallel_next_rank(),
)
ops.append(send_next_op)
if recv_next_shape_tensor is not None:
recv_next_op = torch.distributed.P2POp(
torch.distributed.irecv, recv_next_shape_tensor,
mpu.get_pipeline_model_parallel_next_rank())
torch.distributed.irecv,
recv_next_shape_tensor,
get_pipeline_model_parallel_next_rank(),
)
ops.append(recv_next_op)
if len(ops) > 0:
reqs = torch.distributed.batch_isend_irecv(ops)
......@@ -106,15 +117,126 @@ def _communicate_shapes(tensor_send_next, tensor_send_prev,
return recv_prev_shape, recv_next_shape
def _communicate(*, tensor_send_next: Optional[torch.Tensor],
tensor_send_prev: Optional[torch.Tensor],
recv_prev: bool,
recv_next: bool,
tensor_shape: Shape,
dtype: Optional[torch.dtype],
variable_seq_lengths: bool = False,
use_ring_exchange_p2p: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
def _batched_p2p_ops(
*,
tensor_send_prev: Optional[torch.Tensor],
tensor_recv_prev: Optional[torch.Tensor],
tensor_send_next: Optional[torch.Tensor],
tensor_recv_next: Optional[torch.Tensor],
group: torch.distributed.ProcessGroup
):
ops = []
if tensor_send_prev is not None:
send_prev_op = torch.distributed.P2POp(
torch.distributed.isend,
tensor_send_prev,
get_pipeline_model_parallel_prev_rank(),
group,
)
ops.append(send_prev_op)
if tensor_recv_prev is not None:
recv_prev_op = torch.distributed.P2POp(
torch.distributed.irecv,
tensor_recv_prev,
get_pipeline_model_parallel_prev_rank(),
group,
)
ops.append(recv_prev_op)
if tensor_send_next is not None:
send_next_op = torch.distributed.P2POp(
torch.distributed.isend,
tensor_send_next,
get_pipeline_model_parallel_next_rank(),
group,
)
ops.append(send_next_op)
if tensor_recv_next is not None:
recv_next_op = torch.distributed.P2POp(
torch.distributed.irecv,
tensor_recv_next,
get_pipeline_model_parallel_next_rank(),
group,
)
ops.append(recv_next_op)
if len(ops) > 0:
reqs = torch.distributed.batch_isend_irecv(ops)
else:
reqs = []
return reqs
def _p2p_ops(
*,
tensor_send_prev: Optional[torch.Tensor],
tensor_recv_prev: Optional[torch.Tensor],
tensor_send_next: Optional[torch.Tensor],
tensor_recv_next: Optional[torch.Tensor],
group: torch.distributed.ProcessGroup
):
reqs = []
rank = get_pipeline_model_parallel_rank()
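# Note: even pipeline ranks post their sends before their receives while odd ranks
# do the opposite; pairing the orderings this way is a common way to avoid deadlock
# when matching unbatched isend/irecv calls between neighbouring stages.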
if get_pipeline_model_parallel_rank() % 2 == 0:
if tensor_send_next is not None:
send_next_req = torch.distributed.isend(
tensor=tensor_send_next, dst=get_pipeline_model_parallel_next_rank(), group=group,
)
reqs.append(send_next_req)
if tensor_recv_prev is not None:
recv_prev_req = torch.distributed.irecv(
tensor=tensor_recv_prev, src=get_pipeline_model_parallel_prev_rank(), group=group,
)
reqs.append(recv_prev_req)
if tensor_send_prev is not None:
send_prev_req = torch.distributed.isend(
tensor=tensor_send_prev, dst=get_pipeline_model_parallel_prev_rank(), group=group,
)
reqs.append(send_prev_req)
if tensor_recv_next is not None:
recv_next_req = torch.distributed.irecv(
tensor=tensor_recv_next, src=get_pipeline_model_parallel_next_rank(), group=group,
)
reqs.append(recv_next_req)
else:
if tensor_recv_prev is not None:
recv_prev_req = torch.distributed.irecv(
tensor=tensor_recv_prev, src=get_pipeline_model_parallel_prev_rank(), group=group,
)
reqs.append(recv_prev_req)
if tensor_send_next is not None:
send_next_req = torch.distributed.isend(
tensor=tensor_send_next, dst=get_pipeline_model_parallel_next_rank(), group=group,
)
reqs.append(send_next_req)
if tensor_recv_next is not None:
recv_next_req = torch.distributed.irecv(
tensor=tensor_recv_next, src=get_pipeline_model_parallel_next_rank(), group=group,
)
reqs.append(recv_next_req)
if tensor_send_prev is not None:
send_prev_req = torch.distributed.isend(
tensor=tensor_send_prev, dst=get_pipeline_model_parallel_prev_rank(), group=group,
)
reqs.append(send_prev_req)
return reqs
def _communicate(
*,
tensor_send_next: Optional[torch.Tensor],
tensor_send_prev: Optional[torch.Tensor],
recv_prev: bool,
recv_next: bool,
tensor_shape: Shape,
config: ModelParallelConfig,
wait_on_reqs: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Communicate tensors between stages. Used as helper method in other
communication methods that are used in megatron/schedules.py.
......@@ -136,23 +258,9 @@ def _communicate(*, tensor_send_next: Optional[torch.Tensor],
tensors sent and received in a single function call are
the same shape).
dtype (torch.dtype, required if either recv_{prev,next} is True):
this must be the type of the tensors that will be
received, will typically be params_dtype, but in the case
of fp32 residual connections might be torch.float.
variable_seq_lengths (bool, optional, default=False):
Support for variable sequence lengths across
microbatches. Setting this communicates the size of
tensors during pipeline parallelism communication, because
of this extra overhead it should only be set if the
sequence length is not constant during training.
use_ring_exchange_p2p (bool, optional, default = False):
Use custom ring_exchange kernel instead of
torch.distributed.batch_isend_irecv(). Requires custom
built torch with torch.distributed.ring_exchange.
wait_on_reqs (boolean, optional, default=False):
For non-batched p2p communication, wait on each request
before returning.
Returns:
tuple containing
......@@ -167,84 +275,79 @@ def _communicate(*, tensor_send_next: Optional[torch.Tensor],
tensor_recv_prev = None
tensor_recv_next = None
if not variable_seq_lengths:
if not config.variable_seq_lengths:
recv_prev_shape = tensor_shape
recv_next_shape = tensor_shape
else:
recv_prev_shape, recv_next_shape = \
_communicate_shapes(tensor_send_next,
tensor_send_prev,
recv_prev,
recv_next)
recv_prev_shape, recv_next_shape = _communicate_shapes(
tensor_send_next, tensor_send_prev, recv_prev, recv_next, config
)
if recv_prev:
if dtype is None:
raise RuntimeError("dtype must be provided if recv_prev is True")
if config.pipeline_dtype is None:
raise RuntimeError("pipeline_dtype must be provided if recv_prev is True")
if tensor_shape is None:
raise RuntimeError(
"tensor_shape must be specified if recv_prev is True. "
"Common tensor_shape is (seq_length, micro_batch_size, hidden_size)"
)
tensor_recv_prev = torch.empty(recv_prev_shape,
requires_grad=True,
device=torch.cuda.current_device(),
dtype=dtype)
tensor_recv_prev = torch.empty(
recv_prev_shape,
requires_grad=True,
device=torch.cuda.current_device(),
dtype=config.pipeline_dtype,
)
if recv_next:
if dtype is None:
if config.pipeline_dtype is None:
raise RuntimeError("dtype must be provided if recv_next is True")
if tensor_shape is None:
raise RuntimeError(
"tensor_shape must be specified if recv_next is True. "
"Common tensor_shape is (seq_length, micro_batch_size, hidden_size)"
)
tensor_recv_next = torch.empty(recv_next_shape,
requires_grad=True,
device=torch.cuda.current_device(),
dtype=dtype)
tensor_recv_next = torch.empty(
recv_next_shape,
requires_grad=True,
device=torch.cuda.current_device(),
dtype=config.pipeline_dtype,
)
# Send tensors in both the forward and backward directions as appropriate.
if use_ring_exchange_p2p:
torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
tensor_recv_prev=tensor_recv_prev,
tensor_send_next=tensor_send_next,
tensor_recv_next=tensor_recv_next,
group=get_pipeline_model_parallel_group())
if config.use_ring_exchange_p2p:
def _ring_exchange_wrapper(**kwargs):
torch.distributed.ring_exchange(**kwargs)
return []
p2p_func = _ring_exchange_wrapper
elif config.batch_p2p_comm:
assert wait_on_reqs
p2p_func = _batched_p2p_ops
else:
ops = []
if tensor_send_prev is not None:
send_prev_op = torch.distributed.P2POp(
torch.distributed.isend, tensor_send_prev,
get_pipeline_model_parallel_prev_rank())
ops.append(send_prev_op)
if tensor_recv_prev is not None:
recv_prev_op = torch.distributed.P2POp(
torch.distributed.irecv, tensor_recv_prev,
get_pipeline_model_parallel_prev_rank())
ops.append(recv_prev_op)
if tensor_send_next is not None:
send_next_op = torch.distributed.P2POp(
torch.distributed.isend, tensor_send_next,
get_pipeline_model_parallel_next_rank())
ops.append(send_next_op)
if tensor_recv_next is not None:
recv_next_op = torch.distributed.P2POp(
torch.distributed.irecv, tensor_recv_next,
get_pipeline_model_parallel_next_rank())
ops.append(recv_next_op)
if len(ops) > 0:
reqs = torch.distributed.batch_isend_irecv(ops)
for req in reqs:
req.wait()
p2p_func = _p2p_ops
reqs = p2p_func(
tensor_send_prev=tensor_send_prev,
tensor_recv_prev=tensor_recv_prev,
tensor_send_next=tensor_send_next,
tensor_recv_next=tensor_recv_next,
group=get_pipeline_model_parallel_group(),
)
if wait_on_reqs and len(reqs) > 0:
for req in reqs:
req.wait()
reqs = None
if config.batch_p2p_comm and config.batch_p2p_sync:
# To protect against race condition when using batch_isend_irecv().
# User should assert that we have a modern enough PyTorch to not need this
torch.cuda.synchronize()
return tensor_recv_prev, tensor_recv_next
return tensor_recv_prev, tensor_recv_next, reqs
def recv_forward(tensor_shape: Shape,
dtype: torch.dtype,
timers: Callable = None) -> torch.Tensor:
def recv_forward(tensor_shape: Shape, config: ModelParallelConfig) -> torch.Tensor:
""" Receive tensor from previous rank in pipeline (forward receive).
......@@ -254,23 +357,22 @@ def recv_forward(tensor_shape: Shape,
if core.parallel_state.is_pipeline_first_stage():
input_tensor = None
else:
if timers is not None:
timers('forward-recv', log_level=2).start()
input_tensor, _ = _communicate(
if config.timers is not None:
config.timers('forward-recv', log_level=2).start()
input_tensor, _, _ = _communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_prev=True,
recv_next=False,
tensor_shape=tensor_shape,
dtype=dtype)
if timers is not None:
timers('forward-recv').stop()
config=config,
)
if config.timers is not None:
config.timers('forward-recv').stop()
return input_tensor
def recv_backward(tensor_shape: Shape,
dtype: torch.dtype,
timers: Callable = None) -> torch.Tensor:
def recv_backward(tensor_shape: Shape, config: ModelParallelConfig) -> torch.Tensor:
"""Receive tensor from next rank in pipeline (backward receive).
See _communicate for argument details.
......@@ -278,65 +380,65 @@ def recv_backward(tensor_shape: Shape,
if core.parallel_state.is_pipeline_last_stage():
output_tensor_grad = None
else:
if timers is not None:
timers('backward-recv', log_level=2).start()
_, output_tensor_grad = _communicate(
if config.timers is not None:
config.timers('backward-recv', log_level=2).start()
_, output_tensor_grad, _ = _communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_prev=False,
recv_next=True,
tensor_shape=tensor_shape,
dtype=dtype)
if timers is not None:
timers('backward-recv').stop()
config=config,
)
if config.timers is not None:
config.timers('backward-recv').stop()
return output_tensor_grad
def send_forward(output_tensor: torch.Tensor,
timers: Callable = None) -> None:
def send_forward(output_tensor: torch.Tensor, config: ModelParallelConfig) -> None:
"""Send tensor to next rank in pipeline (forward send).
See _communicate for argument details.
"""
if not core.parallel_state.is_pipeline_last_stage():
if timers is not None:
timers('forward-send', log_level=2).start()
if config.timers is not None:
config.timers('forward-send', log_level=2).start()
_communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=False,
recv_next=False,
tensor_shape=None,
dtype=None)
if timers is not None:
timers('forward-send').stop()
config=config,
)
if config.timers is not None:
config.timers('forward-send').stop()
def send_backward(input_tensor_grad: torch.Tensor,
timers: Callable = None) -> None:
def send_backward(input_tensor_grad: torch.Tensor, config: ModelParallelConfig) -> None:
"""Send tensor to previous rank in pipeline (backward send).
See _communicate for argument details.
"""
if not core.parallel_state.is_pipeline_first_stage():
if timers is not None:
timers('backward-send', log_level=2).start()
if config.timers is not None:
config.timers('backward-send', log_level=2).start()
_communicate(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=False,
recv_next=False,
tensor_shape=None,
dtype=None)
if timers is not None:
timers('backward-send').stop()
config=config,
)
if config.timers is not None:
config.timers('backward-send').stop()
def send_forward_recv_backward(output_tensor: torch.Tensor,
tensor_shape: Shape,
dtype: torch.dtype,
timers: Callable = None) -> torch.Tensor:
def send_forward_recv_backward(
output_tensor: torch.Tensor, tensor_shape: Shape, config: ModelParallelConfig
) -> torch.Tensor:
"""Batched send and recv with next rank in pipeline.
See _communicate for argument details.
......@@ -344,24 +446,24 @@ def send_forward_recv_backward(output_tensor: torch.Tensor,
if core.parallel_state.is_pipeline_last_stage():
output_tensor_grad = None
else:
if timers is not None:
timers('forward-send-backward-recv', log_level=2).start()
_, output_tensor_grad = _communicate(
if config.timers is not None:
config.timers('forward-send-backward-recv', log_level=2).start()
_, output_tensor_grad, _ = _communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=False,
recv_next=True,
tensor_shape=tensor_shape,
dtype=dtype)
if timers is not None:
timers('forward-send-backward-recv').stop()
config=config,
)
if config.timers is not None:
config.timers('forward-send-backward-recv').stop()
return output_tensor_grad
def send_backward_recv_forward(input_tensor_grad: torch.Tensor,
tensor_shape: Shape,
dtype: torch.dtype,
timers: Callable = None) -> torch.Tensor:
def send_backward_recv_forward(
input_tensor_grad: torch.Tensor, tensor_shape: Shape, config: ModelParallelConfig
) -> torch.Tensor:
"""Batched send and recv with previous rank in pipeline.
See _communicate for argument details.
......@@ -369,88 +471,101 @@ def send_backward_recv_forward(input_tensor_grad: torch.Tensor,
if core.parallel_state.is_pipeline_first_stage():
input_tensor = None
else:
if timers is not None:
timers('backward-send-forward-recv', log_level=2).start()
input_tensor, _ = _communicate(
if config.timers is not None:
config.timers('backward-send-forward-recv', log_level=2).start()
input_tensor, _, _ = _communicate(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=True,
recv_next=False,
tensor_shape=tensor_shape,
dtype=dtype)
if timers is not None:
timers('backward-send-forward-recv').stop()
config=config,
)
if config.timers is not None:
config.timers('backward-send-forward-recv').stop()
return input_tensor
def send_forward_recv_forward(output_tensor: torch.Tensor,
recv_prev: bool,
tensor_shape: Shape,
dtype: torch.dtype,
timers: Callable = None) -> torch.Tensor:
def send_forward_recv_forward(
output_tensor: torch.Tensor,
recv_prev: bool,
tensor_shape: Shape,
config: ModelParallelConfig,
overlap_p2p_comm: bool = False,
) -> torch.Tensor:
"""Batched recv from previous rank and send to next rank in pipeline.
See _communicate for argument details.
"""
if timers is not None:
timers('forward-send-forward-recv', log_level=2).start()
input_tensor, _ = _communicate(
if config.timers is not None:
config.timers('forward-send-forward-recv', log_level=2).start()
input_tensor, _, wait_handles = _communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=recv_prev,
recv_next=False,
tensor_shape=tensor_shape,
dtype=dtype)
if timers is not None:
timers('forward-send-forward-recv').stop()
wait_on_reqs=(not overlap_p2p_comm),
config=config,
)
if config.timers is not None:
config.timers('forward-send-forward-recv').stop()
if overlap_p2p_comm:
return input_tensor, wait_handles
return input_tensor
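# Illustrative usage sketch (hypothetical helper; `output_tensor`, `tensor_shape`, and
# `config` are supplied by the caller): with overlap_p2p_comm=True the call returns the
# outstanding requests so local compute can overlap the p2p exchange.
def _example_overlapped_forward_exchange(output_tensor, tensor_shape, config):
    """Hypothetical sketch: overlap the forward p2p exchange with local work."""
    input_tensor, wait_handles = send_forward_recv_forward(
        output_tensor, recv_prev=True, tensor_shape=tensor_shape,
        config=config, overlap_p2p_comm=True,
    )
    # ... do local compute here while the exchange is in flight ...
    for req in wait_handles:
        req.wait()  # input_tensor is only safe to read after the requests complete
    return input_tensor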
def send_backward_recv_backward(input_tensor_grad: torch.Tensor,
recv_next: bool,
tensor_shape: Shape,
dtype: torch.dtype,
timers: Callable = None) -> torch.Tensor:
def send_backward_recv_backward(
input_tensor_grad: torch.Tensor,
recv_next: bool,
tensor_shape: Shape,
config: ModelParallelConfig,
overlap_p2p_comm: bool = False,
) -> torch.Tensor:
"""Batched recv from next rank and send to previous rank in pipeline.
See _communicate for argument details.
"""
if timers is not None:
timers('backward-send-backward-recv', log_level=2).start()
_, output_tensor_grad = _communicate(
if config.timers is not None:
config.timers('backward-send-backward-recv', log_level=2).start()
_, output_tensor_grad, wait_handles = _communicate(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=False,
recv_next=recv_next,
tensor_shape=tensor_shape,
dtype=dtype)
if timers is not None:
timers('backward-send-backward-recv').stop()
wait_on_reqs=(not overlap_p2p_comm),
config=config,
)
if config.timers is not None:
config.timers('backward-send-backward-recv').stop()
if overlap_p2p_comm:
return output_tensor_grad, wait_handles
return output_tensor_grad
def send_forward_backward_recv_forward_backward(
output_tensor: torch.Tensor,
input_tensor_grad: torch.Tensor,
recv_prev: bool,
recv_next: bool,
tensor_shape: Shape,
dtype: torch.dtype,
timers: Callable = None) -> Tuple[torch.Tensor, torch.Tensor]:
output_tensor: torch.Tensor,
input_tensor_grad: torch.Tensor,
recv_prev: bool,
recv_next: bool,
tensor_shape: Shape,
config: ModelParallelConfig,
) -> torch.Tensor:
"""Batched send and recv with previous and next ranks in pipeline.
See _communicate for argument details.
"""
if timers is not None:
timers('forward-backward-send-forward-backward-recv',
log_level=2).start()
input_tensor, output_tensor_grad = _communicate(
if config.timers is not None:
config.timers('forward-backward-send-forward-backward-recv', log_level=2).start()
input_tensor, output_tensor_grad, _ = _communicate(
tensor_send_next=output_tensor,
tensor_send_prev=input_tensor_grad,
recv_prev=recv_prev,
recv_next=recv_next,
tensor_shape=tensor_shape,
dtype=dtype)
if timers is not None:
timers('forward-backward-send-forward-backward-recv').stop()
config=config,
)
if config.timers is not None:
config.timers('forward-backward-send-forward-backward-recv').stop()
return input_tensor, output_tensor_grad
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
from contextlib import contextmanager, nullcontext
from typing import Optional, List, Union, Callable, Any
import contextlib
from typing import Callable, Iterator, List, Optional, Union
import torch
from torch.autograd.variable import Variable
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron import core
from megatron.core import parallel_state
from megatron.core.enums import ModelType
from megatron.core.utils import get_attr_wrapped_model, get_model_type
from megatron.core.pipeline_parallel import p2p_communication
from megatron.core.utils import get_attr_wrapped_model, get_model_config, get_model_type
# Types
Shape = Union[List[int], torch.Size]
def get_forward_backward_func():
"""Retrieves the appropriate forward_backward function given the
configuration of parallel_state.
......@@ -24,6 +26,10 @@ def get_forward_backward_func():
world size and virtual pipeline model parallel world size in the
global parallel_state.
Note that if using sequence parallelism, the sequence length component of
the tensor shape is updated to original_sequence_length /
tensor_model_parallel_world_size.
The function returned takes the following arguments:
forward_step_func (required): A function that takes a data
......@@ -32,6 +38,13 @@ def get_forward_backward_func():
take one torch.Tensor and return a torch.Tensor of loss and a
dictionary of string -> torch.Tensor.
A third argument, checkpoint_activations_microbatch, indicates
that the activations for this microbatch should be
checkpointed. A None value for this argument indicates that
the default from the configuration should be used. This is
used when num_microbatches_with_partial_activation_checkpoints is set.
For example:
def loss_func(loss_mask, output_tensor):
......@@ -54,44 +67,28 @@ def get_forward_backward_func():
data_iterator (required): an iterator over the data, will be
passed as is to forward_step_func
passed as is to forward_step_func. Expected to be a list of
iterators in the case of interleaved pipeline parallelism.
model (required): the actual model. A torch.nn.Module or, in the
case of interleaving, a list of torch.nn.Module
model (required): the actual model. Expected to be a list of modules in the case of interleaved
pipeline parallelism. Must be a (potentially wrapped) megatron.core.models.MegatronModule.
num_microbatches (int, required):
The number of microbatches to go through
dtype (required when using pipeline parallelism): dtype used in
p2p communication, usually params_dtype
tensor_shape (required when using pipeline parallelism): Shape of
tensor. The tensor is expected to be 3D and its order of
dimension is supposed to be ``(sequence, batch, hidden)``.
seq_length (int, required): Sequence length of the current global batch. If this is a dual-stack
transformer, this is the encoder's sequence length. This is ignored if variable_seq_lengths
in the config is True. Otherwise, each microbatch in the current global batch size must use
this sequence length.
decoder_seq_length (int, required for ModelType.encoder_and_decoder models):
Sequence length of the decoder portion, used to determine tensor shapes.
micro_batch_size (int, required): The number of sequences in a microbatch.
grad_scaler (optional, default=None): If using loss scaling,
this function should take the loss and return the scaled
loss. If None, no function is called on the loss.
decoder_seq_length (int, optional): The sequence length for the decoder in a dual-stack
transformer. This is ignored for a single-stack transformer.
sequence_parallel (optional, default=False):
Set to :obj:`True` for this function to handle sequence
length. When :obj:`True`, the sequence length on each tensor
model parallel rank is updated to
:math:`original\_sequence\_length /
tensor\_model\_parallel\_world\_size`.
TODO: Do we need this? Just roll into tensor_shape arg?
forward_only (optional, default = False): Perform only the forward step
forward_only (optional, default=False): Perform only the forward step
timers (optional, default=None): TODO
collect_non_loss_data: TODO
enable_autocast (optional, default=False): If True, runs the
forward_step_func call inside torch.autocast context
collect_non_loss_data (optional, bool, default=False): TODO
"""
pipeline_model_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
......@@ -104,24 +101,20 @@ def get_forward_backward_func():
forward_backward_func = forward_backward_no_pipelining
return forward_backward_func
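# Illustrative usage sketch (hypothetical helper; the iterator, model list, and size
# values are placeholders supplied by the training loop, as documented above):
def _example_run_schedule(forward_step_func, data_iterator, model, num_microbatches,
                          seq_length, micro_batch_size):
    forward_backward_func = get_forward_backward_func()
    losses_reduced = forward_backward_func(
        forward_step_func=forward_step_func,
        data_iterator=data_iterator,
        model=model,                      # list of chunks when interleaving
        num_microbatches=num_microbatches,
        seq_length=seq_length,
        micro_batch_size=micro_batch_size,
        decoder_seq_length=None,
        forward_only=False,
    )
    return losses_reduced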
def deallocate_output_tensor(out):
def deallocate_output_tensor(out, deallocate_pipeline_outputs=False):
'''Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field.
This method should be called right after the output tensor has been
sent to the next pipeline stage. At this point, the output tensor is
only useful for its '.grad_fn' field, and not its '.data'.
'''
if out is None:
if (out is None) or (not deallocate_pipeline_outputs):
return
assert isinstance(out, torch.Tensor), \
"expected Tensor, found %s." % type(out).__name__
assert out._base is None, \
"counter-productive to free a view of another tensor."
out.data = torch.empty(
(1,),
device = out.device,
dtype = out.dtype,
)
assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__
assert out._base is None, "counter-productive to free a view of another tensor."
out.data = torch.empty((1,), device=out.device, dtype=out.dtype,)
def custom_backward(output, grad_output):
'''Directly call C++ autograd engine.
......@@ -132,53 +125,48 @@ def custom_backward(output, grad_output):
grad have the same shape, while C++'s 'backward' does not.
'''
assert output.numel() == 1, \
"output should be pseudo-'freed' in schedule, to optimize memory"
assert isinstance(output, torch.Tensor), \
"output == '%s'." % type(output).__name__
assert isinstance(grad_output, (torch.Tensor, type(None))), \
assert output.numel() == 1, "output should be pseudo-'freed' in schedule, to optimize memory"
assert isinstance(output, torch.Tensor), "output == '%s'." % type(output).__name__
assert isinstance(grad_output, (torch.Tensor, type(None))), (
"grad_output == '%s'." % type(grad_output).__name__
)
# Handle scalar output
if grad_output is None:
assert output.numel() == 1, "implicit grad requires scalar output."
grad_output = torch.ones_like(
output,
memory_format = torch.preserve_format,
)
grad_output = torch.ones_like(output, memory_format=torch.preserve_format,)
# Call c++ engine [ see torch/csrc/autograd/python_engine.cpp ]
Variable._execution_engine.run_backward(
tensors = (output,),
grad_tensors = (grad_output,),
keep_graph = False,
create_graph = False,
inputs = tuple(),
tensors=(output,),
grad_tensors=(grad_output,),
keep_graph=False,
create_graph=False,
inputs=tuple(),
allow_unreachable=True,
accumulate_grad=True,
)
def forward_step(forward_step_func,
data_iterator,
model,
num_microbatches,
input_tensor,
forward_data_store,
timers,
collect_non_loss_data=False,
enable_autocast=False):
def forward_step(
forward_step_func,
data_iterator,
model,
num_microbatches,
input_tensor,
forward_data_store,
config,
collect_non_loss_data=False,
checkpoint_activations_microbatch=None,
):
"""Forward step for passed-in model.
If first stage, input tensor is obtained from data_iterator, otherwise
passed-in input_tensor is used.
Returns output tensor."""
if timers is not None:
timers('forward-compute', log_level=2).start()
if config.timers is not None:
config.timers('forward-compute', log_level=2).start()
unwrap_output_tensor = False
if not isinstance(input_tensor, list):
......@@ -188,9 +176,17 @@ def forward_step(forward_step_func,
set_input_tensor = get_attr_wrapped_model(model, "set_input_tensor")
set_input_tensor(input_tensor)
context_manager = torch.autocast("cuda") if enable_autocast else nullcontext()
if config.enable_autocast:
context_manager = torch.autocast("cuda", dtype=config.autocast_dtype)
else:
context_manager = contextlib.nullcontext()
with context_manager:
output_tensor, loss_func = forward_step_func(data_iterator, model)
if checkpoint_activations_microbatch is None:
output_tensor, loss_func = forward_step_func(data_iterator, model)
else:
output_tensor, loss_func = forward_step_func(
data_iterator, model, checkpoint_activations_microbatch
)
if parallel_state.is_pipeline_last_stage():
if not collect_non_loss_data:
......@@ -202,24 +198,24 @@ def forward_step(forward_step_func,
data = loss_func(output_tensor, non_loss_data=True)
forward_data_store.append(data)
if timers is not None:
timers('forward-compute').stop()
if config.timers is not None:
config.timers('forward-compute').stop()
# If T5 model (or other model with encoder and decoder)
# and in decoder stack, then send encoder_hidden_state
# downstream as well.
model_type = get_model_type(model)
if parallel_state.is_pipeline_stage_after_split() and \
model_type == ModelType.encoder_and_decoder:
if (
parallel_state.is_pipeline_stage_after_split()
and model_type == ModelType.encoder_and_decoder
):
return [output_tensor, input_tensor[-1]]
if unwrap_output_tensor:
return output_tensor
return [output_tensor]
def backward_step(grad_scaler, input_tensor, output_tensor,
output_tensor_grad, model_type, timers):
def backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config):
"""Backward step through passed-in output tensor.
If last stage, output_tensor_grad is None, otherwise gradient of loss
......@@ -232,8 +228,8 @@ def backward_step(grad_scaler, input_tensor, output_tensor,
# needs to be modified slightly to support arbitrary numbers of skip
# connections.
if timers is not None:
timers('backward-compute', log_level=2).start()
if config.timers is not None:
config.timers('backward-compute', log_level=2).start()
# Retain the grad on the input_tensor.
unwrap_input_tensor_grad = False
......@@ -250,9 +246,13 @@ def backward_step(grad_scaler, input_tensor, output_tensor,
output_tensor_grad = [output_tensor_grad]
# Backward pass.
if output_tensor_grad[0] is None and grad_scaler is not None:
output_tensor = grad_scaler(output_tensor[0])
custom_backward(output_tensor[0], output_tensor_grad[0])
if output_tensor_grad[0] is None and config.grad_scale_func is not None:
output_tensor[0] = config.grad_scale_func(output_tensor[0])
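# When pipeline outputs were pseudo-freed by deallocate_output_tensor(), the saved
# output is a placeholder scalar, so the C++ engine is invoked directly via
# custom_backward(); otherwise the regular torch.autograd.backward entry point is used.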
if config.deallocate_pipeline_outputs:
custom_backward(output_tensor[0], output_tensor_grad[0])
else:
torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0])
# Collect the grad of the input_tensor.
input_tensor_grad = [None]
......@@ -266,42 +266,34 @@ def backward_step(grad_scaler, input_tensor, output_tensor,
# Handle single skip connection if it exists (encoder_hidden_state in
# model with encoder and decoder).
if parallel_state.get_pipeline_model_parallel_world_size() > 1 and \
parallel_state.is_pipeline_stage_after_split() and \
model_type == ModelType.encoder_and_decoder:
if (
parallel_state.get_pipeline_model_parallel_world_size() > 1
and parallel_state.is_pipeline_stage_after_split()
and model_type == ModelType.encoder_and_decoder
):
if output_tensor_grad[1] is not None:
input_tensor_grad[-1].add_(output_tensor_grad[1])
if unwrap_input_tensor_grad:
input_tensor_grad = input_tensor_grad[0]
if timers is not None:
timers('backward-compute').stop()
if config.timers is not None:
config.timers('backward-compute').stop()
return input_tensor_grad
@contextmanager
def dummy_handler():
try:
yield
finally:
pass
def forward_backward_no_pipelining(*,
forward_step_func,
data_iterator,
model: Union[torch.nn.Module, List[torch.nn.Module]],
num_microbatches: int,
dtype: Optional[torch.dtype] = None, # unused
tensor_shape: Optional[Shape] = None, # unused
decoder_seq_length: Optional[int] = None, # unused
grad_scaler: Callable = None,
sequence_parallel: bool = False, # unused
forward_only: bool = False,
timers: Callable = None,
collect_non_loss_data: bool = False,
enable_autocast: bool = False):
def forward_backward_no_pipelining(
*,
forward_step_func,
data_iterator: Union[Iterator, List[Iterator]],
model: Union[torch.nn.Module, List[torch.nn.Module]],
num_microbatches: int,
seq_length: int, # unused
micro_batch_size: int, # unused
decoder_seq_length: int = None, # unused
forward_only: bool = False,
collect_non_loss_data: bool = False,
):
"""Run forward and backward passes with no pipeline parallelism
(no inter-stage communication).
......@@ -310,57 +302,121 @@ def forward_backward_no_pipelining(*,
See get_forward_backward_func() for argument details
"""
assert len(model) == 1
model = model[0]
context_handler = dummy_handler
if isinstance(model, torchDDP):
context_handler = model.no_sync
if isinstance(model, list):
assert len(model) == 1, "non-pipeline-parallel schedule does not support model chunking"
model = model[0]
if isinstance(data_iterator, list):
assert (
len(data_iterator) == 1
), "non-pipeline-parallel schedule does not support model chunking"
data_iterator = data_iterator[0]
config = get_model_config(model)
no_sync_func = config.no_sync_func
if no_sync_func is None and isinstance(model, torchDDP):
no_sync_func = model.no_sync
if no_sync_func is None:
no_sync_func = contextlib.nullcontext
model_type = get_model_type(model)
forward_data_store = []
input_tensor, output_tensor_grad = None, None
with context_handler():
with no_sync_func():
for i in range(num_microbatches - 1):
output_tensor = forward_step(forward_step_func, data_iterator,
model, num_microbatches, input_tensor, forward_data_store,
timers, collect_non_loss_data, enable_autocast)
output_tensor = forward_step(
forward_step_func,
data_iterator,
model,
num_microbatches,
input_tensor,
forward_data_store,
config,
collect_non_loss_data,
)
if not forward_only:
backward_step(grad_scaler, input_tensor, output_tensor,
output_tensor_grad, model_type, timers)
backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config)
# Run computation for last microbatch out of context handler (want to
# synchronize gradients).
output_tensor = forward_step(forward_step_func, data_iterator,
model, num_microbatches, input_tensor, forward_data_store,
timers, collect_non_loss_data, enable_autocast)
output_tensor = forward_step(
forward_step_func,
data_iterator,
model,
num_microbatches,
input_tensor,
forward_data_store,
config,
collect_non_loss_data,
)
if not forward_only:
backward_step(grad_scaler, input_tensor, output_tensor,
output_tensor_grad, model_type, timers)
backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config)
return forward_data_store
def forward_backward_pipelining_with_interleaving(*,
forward_step_func,
data_iterator,
model: Union[torch.nn.Module, List[torch.nn.Module]],
num_microbatches: int,
dtype: torch.dtype,
tensor_shape: Shape,
decoder_seq_length: Optional[int] = None,
grad_scaler: Callable = None,
sequence_parallel: bool = False,
forward_only: bool = False,
timers: Callable = None,
collect_non_loss_data: bool = False,
enable_autocast: bool = False):
def forward_backward_pipelining_with_interleaving(
*,
forward_step_func,
data_iterator: Union[Iterator, List[Iterator]],
model: Union[torch.nn.Module, List[torch.nn.Module]],
num_microbatches: int,
seq_length: int,
micro_batch_size: int,
decoder_seq_length: int = None,
forward_only: bool = False,
collect_non_loss_data: bool = False,
):
"""Run interleaved 1F1B schedule (model split into model chunks), with
communication between pipeline stages as needed.
Returns dictionary with losses if the last stage, empty dict otherwise."""
assert isinstance(model, list), "interleaved pipeline parallelism expected model chunking"
assert all(isinstance(chunk, torch.nn.Module) for chunk in model), "invalid model chunking"
assert isinstance(
data_iterator, list
), "interleaved pipeline parallelism expected each model chunk to have a data iterator"
config = get_model_config(model[0])
if config.overlap_p2p_comm and config.batch_p2p_comm:
raise ValueError("Can not use both overlap_p2p_comm and batch_p2p_comm")
# Disable async grad reductions
no_sync_func = config.no_sync_func
if no_sync_func is None and all(isinstance(chunk, torchDDP) for chunk in model):
def multi_no_sync():
stack = contextlib.ExitStack()
for chunk in model:
stack.enter_context(chunk.no_sync())
return stack
no_sync_func = multi_no_sync
if no_sync_func is None:
no_sync_func = contextlib.nullcontext
no_sync_context = None
def disable_grad_sync():
"""Disable asynchronous grad reductions"""
nonlocal no_sync_context
if no_sync_context is None:
no_sync_context = no_sync_func()
no_sync_context.__enter__()
def enable_grad_sync():
"""Enable asynchronous grad reductions"""
nonlocal no_sync_context
if no_sync_context is not None:
no_sync_context.__exit__(None, None, None)
no_sync_context = None
disable_grad_sync()
# Model chunk IDs with synchronized grads
synchronized_model_chunks = set()
input_tensors = [[] for _ in range(len(model))]
output_tensors = [[] for _ in range(len(model))]
......@@ -381,17 +437,15 @@ def forward_backward_pipelining_with_interleaving(*,
if model_type == ModelType.encoder_and_decoder:
raise RuntimeError("Interleaving is not supported with an encoder and decoder model.")
if decoder_seq_length is not None and decoder_seq_length != seq_length:
raise RuntimeError(
"Interleaving is not supported with a different decoder sequence length."
)
tensor_shape = [seq_length, micro_batch_size, config.hidden_size]
if config.sequence_parallel:
tensor_shape[0] = tensor_shape[0] // parallel_state.get_tensor_model_parallel_world_size()
# Compute number of warmup and remaining microbatches.
num_model_chunks = len(model)
total_num_microbatches = num_microbatches * num_model_chunks
......@@ -409,45 +463,96 @@ def forward_backward_pipelining_with_interleaving(*,
num_warmup_microbatches = total_num_microbatches
all_warmup_microbatches = True
else:
num_warmup_microbatches = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size
num_warmup_microbatches = min(num_warmup_microbatches, total_num_microbatches)
num_microbatches_remaining = total_num_microbatches - num_warmup_microbatches
# Checkpoint the activations of partial Transformer layers in a number of micro-batches
# within the maximum outstanding micro-batch backpropagations.
# Micro-batches with the ids less than 'num_microbatches_with_partial_activation_checkpoints'
# checkpoint partial Transformer layers (or skip checkpointing) and
# the rest of micro-batches within a window of micro-batches checkpoint
# all Transformer layers. The window of micro-batches is set by the maximum
# outstanding backpropagations and becomes smaller at later pipeline stages.
# Please refer to Appendix C in https://arxiv.org/pdf/2205.05198.pdf
max_outstanding_backprops = None
if config.num_microbatches_with_partial_activation_checkpoints is not None:
max_outstanding_backprops = num_warmup_microbatches + 1
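    # Worked example (illustrative): with num_warmup_microbatches=3,
    # max_outstanding_backprops is 4. If
    # config.num_microbatches_with_partial_activation_checkpoints == 2, then in
    # every window of 4 microbatches the first 2 use partial (or no) activation
    # checkpointing and the remaining 2 checkpoint all Transformer layers.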
# Synchronize params for first two model chunks
if config.param_sync_func is not None:
config.param_sync_func(model[0].parameters())
config.param_sync_func(model[1].parameters())
def get_model_chunk_id(microbatch_id, forward):
"""Helper method to get the model chunk ID given the iteration number."""
microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
if not forward:
model_chunk_id = num_model_chunks - model_chunk_id - 1
return model_chunk_id
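    # Worked example (illustrative): with pipeline_parallel_size=4 and
    # num_model_chunks=2, forward microbatches 0-3 map to model chunk 0,
    # microbatches 4-7 to chunk 1, and the pattern then repeats; for the
    # backward direction the mapping is mirrored.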
def is_first_microbatch_for_model_chunk(microbatch_id: int) -> bool:
"""Check if an iteration is the first for a model chunk."""
microbatch_group_size = pipeline_parallel_size * num_model_chunks
num_microbatch_groups = total_num_microbatches // microbatch_group_size
microbatch_group_id = microbatch_id // microbatch_group_size
microbatch_id_in_group = microbatch_id % microbatch_group_size
if microbatch_group_id == 0:
return microbatch_id_in_group % pipeline_parallel_size == 0
else:
return False
def is_last_microbatch_for_model_chunk(microbatch_id: int) -> bool:
"""Check if an iteration is the last for a model chunk."""
microbatch_group_size = pipeline_parallel_size * num_model_chunks
num_microbatch_groups = total_num_microbatches // microbatch_group_size
microbatch_group_id = microbatch_id // microbatch_group_size
microbatch_id_in_group = microbatch_id % microbatch_group_size
if microbatch_group_id == num_microbatch_groups - 1:
return microbatch_id_in_group % pipeline_parallel_size == pipeline_parallel_size - 1
else:
return False
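    # Example (illustrative): with pipeline_parallel_size=2, num_model_chunks=2
    # and total_num_microbatches=8, microbatches 0 and 2 are the "first"
    # microbatches of their model chunks (triggering param sync below), while
    # microbatches 5 and 7 are the "last" ones (triggering grad sync).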
def forward_step_helper(microbatch_id, checkpoint_activations_microbatch):
"""Helper method to run forward step with model split into chunks
(run set_virtual_pipeline_model_parallel_rank() before calling
forward_step())."""
model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
# launch param synchronization for next model chunk
# Note: Asynchronous communication tends to slow down compute.
# To reduce idling from mismatched microbatch times, we launch
# asynchronous communication at the same time across the
# pipeline-parallel group.
if config.param_sync_func is not None:
param_sync_microbatch_id = microbatch_id + pipeline_parallel_rank
if (
param_sync_microbatch_id < total_num_microbatches
and is_first_microbatch_for_model_chunk(param_sync_microbatch_id)
):
param_sync_chunk_id = get_model_chunk_id(param_sync_microbatch_id, forward=True) + 1
if 1 < param_sync_chunk_id < num_model_chunks:
config.param_sync_func(model[param_sync_chunk_id].parameters())
# forward step
if parallel_state.is_pipeline_first_stage():
if len(input_tensors[model_chunk_id]) == len(output_tensors[model_chunk_id]):
input_tensors[model_chunk_id].append(None)
input_tensor = input_tensors[model_chunk_id][-1]
output_tensor = forward_step(
forward_step_func,
data_iterator[model_chunk_id],
model[model_chunk_id],
num_microbatches,
input_tensor,
forward_data_store,
config,
collect_non_loss_data,
checkpoint_activations_microbatch,
)
output_tensors[model_chunk_id].append(output_tensor)
# if forward-only, no need to save tensors for a backward pass
......@@ -464,31 +569,65 @@ def forward_backward_pipelining_with_interleaving(*,
model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
# launch grad synchronization (default)
if config.grad_sync_func is None and is_last_microbatch_for_model_chunk(microbatch_id):
enable_grad_sync()
synchronized_model_chunks.add(model_chunk_id)
if parallel_state.is_pipeline_last_stage():
if len(output_tensor_grads[model_chunk_id]) == 0:
output_tensor_grads[model_chunk_id].append(None)
input_tensor = input_tensors[model_chunk_id].pop(0)
output_tensor = output_tensors[model_chunk_id].pop(0)
output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
input_tensor_grad = backward_step(
input_tensor, output_tensor, output_tensor_grad, model_type, config
)
# launch grad synchronization (custom grad sync)
# Note: Asynchronous communication tends to slow down compute.
# To reduce idling from mismatched microbatch times, we launch
# asynchronous communication at the same time across the
# pipeline-parallel group.
if config.grad_sync_func is not None:
grad_sync_microbatch_id = microbatch_id - pipeline_parallel_rank
if grad_sync_microbatch_id >= 0 and is_last_microbatch_for_model_chunk(
grad_sync_microbatch_id
):
grad_sync_chunk_id = get_model_chunk_id(grad_sync_microbatch_id, forward=False)
enable_grad_sync()
config.grad_sync_func(model[grad_sync_chunk_id].parameters())
synchronized_model_chunks.add(grad_sync_chunk_id)
disable_grad_sync()
return input_tensor_grad
# Run warmup forward passes.
parallel_state.set_virtual_pipeline_model_parallel_rank(0)
input_tensors[0].append(p2p_communication.recv_forward(tensor_shape, config))
fwd_wait_handles = None
bwd_wait_handles = None
for k in range(num_warmup_microbatches):
if fwd_wait_handles is not None:
for req in fwd_wait_handles:
req.wait()
# Decide to checkpoint all layers' activations of the current micro-batch
if max_outstanding_backprops is not None:
checkpoint_activations_microbatch = (
k % max_outstanding_backprops
>= config.num_microbatches_with_partial_activation_checkpoints
)
else:
checkpoint_activations_microbatch = None
output_tensor = forward_step_helper(k, checkpoint_activations_microbatch)
# Determine if tensor should be received from previous stage.
next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True)
recv_prev = True
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
if next_forward_model_chunk_id == 0:
......@@ -502,108 +641,255 @@ def forward_backward_pipelining_with_interleaving(*,
# Send and receive tensors as appropriate (send tensors computed
# in this iteration; receive tensors for next iteration).
if not config.overlap_p2p_comm:
if (
k == (num_warmup_microbatches - 1)
and not forward_only
and not all_warmup_microbatches
):
input_tensor_grad = None
recv_next = True
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
recv_next = False
(
input_tensor,
output_tensor_grad,
) = p2p_communication.send_forward_backward_recv_forward_backward(
output_tensor,
input_tensor_grad,
recv_prev=recv_prev,
recv_next=recv_next,
tensor_shape=tensor_shape,
config=config,
)
output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
else:
input_tensor = p2p_communication.send_forward_recv_forward(
output_tensor, recv_prev=recv_prev, tensor_shape=tensor_shape, config=config
)
input_tensors[next_forward_model_chunk_id].append(input_tensor)
else:
input_tensor, fwd_wait_handles = p2p_communication.send_forward_recv_forward(
output_tensor,
recv_prev=recv_prev,
tensor_shape=tensor_shape,
config=config,
overlap_p2p_comm=True,
)
if (
k == (num_warmup_microbatches - 1)
and not forward_only
and not all_warmup_microbatches
):
input_tensor_grad = None
recv_next = True
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
recv_next = False
(
output_tensor_grad,
bwd_wait_handles,
) = p2p_communication.send_backward_recv_backward(
input_tensor_grad,
recv_next=recv_next,
tensor_shape=tensor_shape,
config=config,
overlap_p2p_comm=True,
)
output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
input_tensors[next_forward_model_chunk_id].append(input_tensor)
deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
# Run 1F1B in steady state.
for k in range(num_microbatches_remaining):
# Forward pass.
forward_k = k + num_warmup_microbatches
# Decide to checkpoint all layers' activations of the current micro-batch
if max_outstanding_backprops is not None:
checkpoint_activations_microbatch = (
forward_k % max_outstanding_backprops
>= config.num_microbatches_with_partial_activation_checkpoints
)
else:
checkpoint_activations_microbatch = None
if config.overlap_p2p_comm:
if fwd_wait_handles is not None:
for req in fwd_wait_handles:
req.wait()
deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
output_tensor = forward_step_helper(forward_k, checkpoint_activations_microbatch)
# Determine if current stage has anything to send in either direction,
# otherwise set tensor to None.
forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
# Last virtual stage no activation tensor to send
if parallel_state.is_pipeline_last_stage():
output_tensor = None
# Determine if peers are sending, and where in data structure to put
# received tensors.
recv_prev = True
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
# First stage is ahead of last stage by (pipeline_parallel_size - 1).
next_forward_model_chunk_id = get_model_chunk_id(
forward_k - (pipeline_parallel_size - 1), forward=True
)
if next_forward_model_chunk_id == (num_model_chunks - 1):
recv_prev = False
next_forward_model_chunk_id += 1
else:
next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True)
# If last iteration, don't receive; we already received one extra
# before the start of the for loop.
if k == (num_microbatches_remaining - 1):
recv_prev = False
# Send activation tensor to the next stage and receive activation tensor from the
# previous stage
input_tensor, fwd_wait_handles = p2p_communication.send_forward_recv_forward(
output_tensor,
recv_prev=recv_prev,
tensor_shape=tensor_shape,
config=config,
overlap_p2p_comm=True,
)
# assert fwd_wait_handles is not None
if bwd_wait_handles is not None:
for req in bwd_wait_handles:
req.wait()
# Backward pass.
backward_k = k
input_tensor_grad = backward_step_helper(backward_k)
backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
# First virtual stage no activation gradient tensor to send
if parallel_state.is_pipeline_first_stage():
input_tensor_grad = None
# Determine if the current virtual stage has an activation gradient tensor to receive
recv_next = True
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
# Last stage is ahead of first stage by (pipeline_parallel_size - 1).
next_backward_model_chunk_id = get_model_chunk_id(
backward_k - (pipeline_parallel_size - 1), forward=False
)
if next_backward_model_chunk_id == 0:
recv_next = False
next_backward_model_chunk_id -= 1
else:
next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False)
output_tensor_grad, bwd_wait_handles = p2p_communication.send_backward_recv_backward(
input_tensor_grad,
recv_next=recv_next,
tensor_shape=tensor_shape,
config=config,
overlap_p2p_comm=True,
)
else: # no p2p overlap
output_tensor = forward_step_helper(forward_k, checkpoint_activations_microbatch)
# Backward pass.
backward_k = k
input_tensor_grad = backward_step_helper(backward_k)
            # Send output_tensor and input_tensor_grad, receive input_tensor
            # and output_tensor_grad.

            # Determine if current stage has anything to send in either direction,
            # otherwise set tensor to None.
            forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
            parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
            if parallel_state.is_pipeline_last_stage():
                output_tensor = None

            backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
            parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
            if parallel_state.is_pipeline_first_stage():
                input_tensor_grad = None

            # Determine if peers are sending, and where in data structure to put
            # received tensors.
            recv_prev = True
            if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
                # First stage is ahead of last stage by (pipeline_parallel_size - 1).
                next_forward_model_chunk_id = get_model_chunk_id(
                    forward_k - (pipeline_parallel_size - 1), forward=True
                )
                if next_forward_model_chunk_id == (num_model_chunks - 1):
                    recv_prev = False
                next_forward_model_chunk_id += 1
            else:
                next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True)

            recv_next = True
            if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
                next_backward_model_chunk_id = get_model_chunk_id(
                    backward_k - (pipeline_parallel_size - 1), forward=False
                )
                if next_backward_model_chunk_id == 0:
                    recv_next = False
                next_backward_model_chunk_id -= 1
            else:
                next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False)

            # If last iteration, don't receive; we already received one extra
            # before the start of the for loop.
            if k == (num_microbatches_remaining - 1):
                recv_prev = False
# Communicate tensors.
(
input_tensor,
output_tensor_grad,
) = p2p_communication.send_forward_backward_recv_forward_backward(
output_tensor,
input_tensor_grad,
recv_prev=recv_prev,
recv_next=recv_next,
tensor_shape=tensor_shape,
config=config,
)
deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
# Put input_tensor and output_tensor_grad in data structures in the
# right location.
if recv_prev:
input_tensors[next_forward_model_chunk_id].append(input_tensor)
if recv_next:
output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad)
deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
# Run cooldown backward passes (flush out pipeline).
if not forward_only:
if config.overlap_p2p_comm and bwd_wait_handles is not None:
for wait_handle in bwd_wait_handles:
wait_handle.wait()
if all_warmup_microbatches:
output_tensor_grads[num_model_chunks - 1].append(
p2p_communication.recv_backward(tensor_shape, config=config)
)
for k in range(num_microbatches_remaining, total_num_microbatches):
input_tensor_grad = backward_step_helper(k)
next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
recv_next = True
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
if next_backward_model_chunk_id == (num_model_chunks - 1):
......@@ -612,18 +898,33 @@ def forward_backward_pipelining_with_interleaving(*,
recv_next = False
output_tensor_grads[next_backward_model_chunk_id].append(
p2p_communication.send_backward_recv_backward(
input_tensor_grad, recv_next=recv_next, tensor_shape=tensor_shape, config=config
)
)
# Launch any remaining grad reductions
enable_grad_sync()
if config.grad_sync_func is not None:
params = []
for model_chunk_id in range(num_model_chunks):
if model_chunk_id not in synchronized_model_chunks:
params.extend(model[model_chunk_id].parameters())
synchronized_model_chunks.add(model_chunk_id)
if params:
config.grad_sync_func(params)
return forward_data_store
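# Illustrative sketch (standalone helper assumed for exposition, not used by the
# schedule above): the warmup count of the interleaved schedule as a pure
# function, ignoring the forward-only and all-warmup special cases.
def _example_interleaved_warmup(
    pipeline_parallel_size, pipeline_parallel_rank, num_model_chunks, num_microbatches
):
    total_num_microbatches = num_microbatches * num_model_chunks
    num_warmup = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
    num_warmup += (num_model_chunks - 1) * pipeline_parallel_size
    return min(num_warmup, total_num_microbatches)


# e.g. _example_interleaved_warmup(4, 0, 2, 8) == 10: the first pipeline rank
# runs 10 warmup forward passes before entering the 1F1B steady state.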
def get_tensor_shapes(
*,
rank: int,
model_type: ModelType,
seq_length: int,
micro_batch_size: int,
decoder_seq_length: int,
config,
):
# Determine right tensor sizes (based on position of rank with respect to split
# rank) and model size.
# Send two tensors if model is T5 and rank is in decoder stage:
......@@ -634,71 +935,63 @@ def get_tensor_shapes(*,
# Otherwise, send one tensor (pre-transpose).
tensor_shapes = []
if config.sequence_parallel:
seq_length = seq_length // parallel_state.get_tensor_model_parallel_world_size()
if model_type == ModelType.encoder_and_decoder:
decoder_seq_length = (
decoder_seq_length // parallel_state.get_tensor_model_parallel_world_size()
)
if model_type == ModelType.encoder_and_decoder:
if parallel_state.is_pipeline_stage_before_split(rank):
tensor_shapes.append((seq_length, micro_batch_size, config.hidden_size))
else:
tensor_shapes.append((decoder_seq_length, micro_batch_size, config.hidden_size))
tensor_shapes.append((seq_length, micro_batch_size, config.hidden_size))
else:
tensor_shapes.append((seq_length, micro_batch_size, config.hidden_size))
return tensor_shapes
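# Example (illustrative): for a decoder-only model with seq_length=2048,
# micro_batch_size=4, config.hidden_size=4096 and sequence parallelism across a
# tensor-model-parallel world size of 8, each communicated tensor has shape
# (256, 4, 4096).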
def recv_forward(tensor_shapes, config):
input_tensors = []
for tensor_shape in tensor_shapes:
if tensor_shape is None:
input_tensors.append(None)
else:
input_tensors.append(p2p_communication.recv_forward(tensor_shape, config))
return input_tensors
def recv_backward(tensor_shapes, config):
output_tensor_grads = []
for tensor_shape in tensor_shapes:
if tensor_shape is None:
output_tensor_grads.append(None)
else:
output_tensor_grads.append(p2p_communication.recv_backward(tensor_shape, config))
return output_tensor_grads
def send_forward(output_tensors, tensor_shapes, config):
if not isinstance(output_tensors, list):
output_tensors = [output_tensors]
for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes):
if tensor_shape is None:
continue
p2p_communication.send_forward(output_tensor, config)
def send_backward(input_tensor_grads, tensor_shapes, config):
if not isinstance(input_tensor_grads, list):
input_tensor_grads = [input_tensor_grads]
for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes):
if tensor_shape is None:
continue
p2p_communication.send_backward(input_tensor_grad, config)
def send_forward_recv_backward(output_tensors, tensor_shapes, config):
if not isinstance(output_tensors, list):
output_tensors = [output_tensors]
output_tensor_grads = []
......@@ -707,12 +1000,13 @@ def send_forward_recv_backward(output_tensors, tensor_shapes, dtype, timers):
output_tensor_grads.append(None)
continue
output_tensor_grad = p2p_communication.send_forward_recv_backward(
output_tensor, tensor_shape, config
)
output_tensor_grads.append(output_tensor_grad)
return output_tensor_grads
def send_backward_recv_forward(input_tensor_grads, tensor_shapes, config):
if not isinstance(input_tensor_grads, list):
input_tensor_grads = [input_tensor_grads]
input_tensors = []
......@@ -721,56 +1015,110 @@ def send_backward_recv_forward(input_tensor_grads, tensor_shapes, dtype, timers)
input_tensors.append(None)
continue
input_tensor = p2p_communication.send_backward_recv_forward(
input_tensor_grad, tensor_shape, config
)
input_tensors.append(input_tensor)
return input_tensors
def forward_backward_pipelining_without_interleaving(
*,
forward_step_func,
data_iterator: Union[Iterator, List[Iterator]],
model: Union[torch.nn.Module, List[torch.nn.Module]],
num_microbatches: int,
seq_length: int,
micro_batch_size: int,
decoder_seq_length: int = None,
forward_only: bool = False,
collect_non_loss_data: bool = False,
):
"""Run non-interleaved 1F1B schedule, with communication between pipeline
stages.
Returns dictionary with losses if the last stage, empty dict otherwise."""
if isinstance(model, list):
assert (
len(model) == 1
), "non-interleaved pipeline parallelism does not support model chunking"
model = model[0]
if isinstance(data_iterator, list):
assert (
len(data_iterator) == 1
), "non-pipeline-parallel schedule does not support model chunking"
data_iterator = data_iterator[0]
config = get_model_config(model)
if config.overlap_p2p_comm:
raise ValueError(
"Non-interleaved pipeline parallelism does not support overlapping p2p communication"
)
# Disable async grad reductions
no_sync_func = config.no_sync_func
if no_sync_func is None and isinstance(model, torchDDP):
no_sync_func = model.no_sync
if no_sync_func is None:
no_sync_func = contextlib.nullcontext
no_sync_context = None
def disable_grad_sync():
"""Disable asynchronous grad reductions"""
nonlocal no_sync_context
if no_sync_context is None:
no_sync_context = no_sync_func()
no_sync_context.__enter__()
def enable_grad_sync():
"""Enable asynchronous grad reductions"""
nonlocal no_sync_context
if no_sync_context is not None:
no_sync_context.__exit__(None, None, None)
no_sync_context = None
disable_grad_sync()
# Compute number of warmup microbatches.
num_warmup_microbatches = (
parallel_state.get_pipeline_model_parallel_world_size()
- parallel_state.get_pipeline_model_parallel_rank()
- 1
)
num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
num_microbatches_remaining = num_microbatches - num_warmup_microbatches
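    # Worked example (illustrative): with a pipeline-parallel world size of 4,
    # rank 1 runs 4 - 1 - 1 = 2 warmup microbatches; with num_microbatches=8
    # that leaves 6 microbatches for the 1F1B steady state.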
# Checkpoint the activations of partial Transformer layers in a number of micro-batches
# within the maximum outstanding micro-batch backpropagations.
# Micro-batches with the ids less than 'num_microbatches_with_partial_activation_checkpoints'
# checkpoint partial Transformer layers (or skip checkpointing) and
# the rest of micro-batches within a window of micro-batches checkpoint
# all Transformer layers. The window of micro-batches is set by the maximum
# outstanding backpropagations and becomes smaller at later pipeline stages.
# Please refer to Appendix C in https://arxiv.org/pdf/2205.05198.pdf
max_outstanding_backprops = None
if config.num_microbatches_with_partial_activation_checkpoints is not None:
max_outstanding_backprops = num_warmup_microbatches + 1
model_type = get_model_type(model)
rank = parallel_state.get_pipeline_model_parallel_rank()
recv_tensor_shapes = get_tensor_shapes(
rank=rank - 1,
model_type=model_type,
seq_length=seq_length,
micro_batch_size=micro_batch_size,
decoder_seq_length=decoder_seq_length,
config=config,
)
send_tensor_shapes = get_tensor_shapes(
rank=rank,
model_type=model_type,
seq_length=seq_length,
micro_batch_size=micro_batch_size,
decoder_seq_length=decoder_seq_length,
config=config,
)
# Input, output tensors only need to be saved when doing backward passes
input_tensors = None
......@@ -782,77 +1130,125 @@ def forward_backward_pipelining_without_interleaving(*,
# Run warmup forward passes.
for i in range(num_warmup_microbatches):
# Decide to checkpoint all layers' activations of the current micro-batch
if max_outstanding_backprops is not None:
checkpoint_activations_microbatch = (
i % max_outstanding_backprops
>= config.num_microbatches_with_partial_activation_checkpoints
)
else:
checkpoint_activations_microbatch = None
input_tensor = recv_forward(recv_tensor_shapes, config)
output_tensor = forward_step(
forward_step_func,
data_iterator,
model,
num_microbatches,
input_tensor,
forward_data_store,
config,
collect_non_loss_data,
checkpoint_activations_microbatch,
)
send_forward(output_tensor, send_tensor_shapes, config)
if not forward_only:
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
deallocate_output_tensor(output_tensor[0], config.deallocate_pipeline_outputs)
# Before running 1F1B, need to receive first forward tensor.
# If all microbatches are run in warmup / cooldown phase, then no need to
# receive this tensor here.
if num_microbatches_remaining > 0:
input_tensor = recv_forward(recv_tensor_shapes, config)
# Run 1F1B in steady state.
for i in range(num_microbatches_remaining):
last_iteration = i == (num_microbatches_remaining - 1)
# Decide to checkpoint all layers' activations of the current micro-batch
if max_outstanding_backprops is not None:
checkpoint_activations_microbatch = (
(i + num_warmup_microbatches) % max_outstanding_backprops
) >= config.num_microbatches_with_partial_activation_checkpoints
else:
checkpoint_activations_microbatch = None
output_tensor = forward_step(
forward_step_func,
data_iterator,
model,
num_microbatches,
input_tensor,
forward_data_store,
config,
collect_non_loss_data,
checkpoint_activations_microbatch,
)
if forward_only:
send_forward(output_tensor, send_tensor_shapes, config)
if not last_iteration:
input_tensor = recv_forward(recv_tensor_shapes, config)
else:
output_tensor_grad = send_forward_recv_backward(
output_tensor, send_tensor_shapes, config
)
# Add input_tensor and output_tensor to end of list.
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
deallocate_output_tensor(output_tensor[0], config.deallocate_pipeline_outputs)
# Pop input_tensor and output_tensor from the start of the list for
# the backward pass.
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
input_tensor_grad = backward_step(
input_tensor, output_tensor, output_tensor_grad, model_type, config
)
if last_iteration:
input_tensor = None
send_backward(input_tensor_grad, recv_tensor_shapes, config)
else:
input_tensor = send_backward_recv_forward(
input_tensor_grad, recv_tensor_shapes, config
)
# Run cooldown backward passes.
if not forward_only:
for i in range(num_warmup_microbatches):
# Enable async grad reduction in the last backward pass
# Note: If grad sync function is provided, only enable
# async grad reduction in first pipeline stage. Other
# pipeline stages do grad reduction during pipeline
# bubble.
if i == num_warmup_microbatches - 1:
if config.grad_sync_func is None or rank == 0:
enable_grad_sync()
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
output_tensor_grad = recv_backward(send_tensor_shapes, config)
input_tensor_grad = backward_step(
input_tensor, output_tensor, output_tensor_grad, model_type, config
)
send_backward(input_tensor_grad, recv_tensor_shapes, config)
# Launch any remaining grad reductions
if no_sync_context is not None:
enable_grad_sync()
if config.grad_sync_func is not None:
config.grad_sync_func(model.parameters())
return forward_data_store
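# Illustrative sketch (a rough re-statement of the schedule selection normally
# done by get_forward_backward_func() in this module; shown here only for
# orientation): which of the three schedules a caller ends up with.
def _example_schedule_selection(pipeline_world_size, virtual_pipeline_world_size):
    if pipeline_world_size > 1:
        if virtual_pipeline_world_size is not None:
            return "forward_backward_pipelining_with_interleaving"
        return "forward_backward_pipelining_without_interleaving"
    return "forward_backward_no_pipelining"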
from .cross_entropy import vocab_parallel_cross_entropy
from .data import broadcast_data
from .layers import (
ColumnParallelLinear,
RowParallelLinear,
VocabParallelEmbedding,
copy_tensor_model_parallel_attributes,
linear_with_grad_accumulation_and_async_allreduce,
param_is_not_tensor_parallel_duplicate,
set_defaults_if_not_set_tensor_model_parallel_attributes,
set_tensor_model_parallel_attributes,
)
from .mappings import (
copy_to_tensor_model_parallel_region,
gather_from_sequence_parallel_region,
gather_from_tensor_model_parallel_region,
scatter_to_sequence_parallel_region,
scatter_to_tensor_model_parallel_region,
)
from .random import checkpoint, get_cuda_rng_tracker, model_parallel_cuda_manual_seed
from .utils import (
gather_split_1d_tensor,
split_tensor_along_last_dim,
split_tensor_into_1d_equal_chunks,
)
__all__ = [
......@@ -38,7 +29,7 @@ __all__ = [
"vocab_parallel_cross_entropy",
# data.py
"broadcast_data",
# layers.py
"ColumnParallelLinear",
"RowParallelLinear",
"VocabParallelEmbedding",
......@@ -51,7 +42,7 @@ __all__ = [
"copy_to_tensor_model_parallel_region",
"gather_from_tensor_model_parallel_region",
"gather_from_sequence_parallel_region",
# "reduce_from_tensor_model_parallel_region",
"scatter_to_tensor_model_parallel_region",
"scatter_to_sequence_parallel_region",
# random.py
......
......@@ -5,22 +5,21 @@ import torch
from megatron.core.parallel_state import (
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from .utils import VocabUtility
class _VocabParallelCrossEntropy(torch.autograd.Function):
@staticmethod
def forward(ctx, vocab_parallel_logits, target, label_smoothing=0.0):
# Maximum value along vocab dimension across all GPUs.
logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
torch.distributed.all_reduce(
logits_max, op=torch.distributed.ReduceOp.MAX, group=get_tensor_model_parallel_group()
)
# Subtract the maximum value.
vocab_parallel_logits = vocab_parallel_logits - logits_max.unsqueeze(dim=-1)
......@@ -29,8 +28,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
partition_vocab_size = vocab_parallel_logits.size()[-1]
rank = get_tensor_model_parallel_rank()
world_size = get_tensor_model_parallel_world_size()
vocab_start_index, vocab_end_index = get_vocab_range(partition_vocab_size, rank, world_size)
# Create a mask of valid vocab ids (1 means it needs to be masked).
target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
......@@ -42,24 +40,27 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
# [*, partition-vocab-size] and target to a 1-D tensor of size [*].
logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
masked_target_1d = masked_target.view(-1)
arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device)
predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
predicted_logits_1d = predicted_logits_1d.clone().contiguous()
predicted_logits = predicted_logits_1d.view_as(target)
predicted_logits[target_mask] = 0.0
# All reduce is needed to get the chunks from other GPUs.
torch.distributed.all_reduce(
predicted_logits,
op=torch.distributed.ReduceOp.SUM,
group=get_tensor_model_parallel_group(),
)
# Sum of exponential of logits along vocab dimension across all GPUs.
exp_logits = vocab_parallel_logits
torch.exp(vocab_parallel_logits, out=exp_logits)
sum_exp_logits = exp_logits.sum(dim=-1)
torch.distributed.all_reduce(
sum_exp_logits,
op=torch.distributed.ReduceOp.SUM,
group=get_tensor_model_parallel_group(),
)
# Loss = log(sum(exp(logits))) - predicted-logit.
loss = torch.log(sum_exp_logits) - predicted_logits
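        # i.e. per token: loss = log(sum_j exp(l_j - m)) - (l_target - m), where
        # m is the all-reduced max logit subtracted above; the m terms cancel,
        # so this equals the usual -log softmax(l)[target] cross entropy.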
......@@ -87,7 +88,6 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
loss = (1.0 - smoothing) * loss - smoothing * mean_log_probs
ctx.label_smoothing, ctx.vocab_size = label_smoothing, vocab_size
# Store softmax, target-mask and masked-target for backward pass.
ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
......@@ -108,8 +108,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
grad_2d = grad_input.view(-1, partition_vocab_size)
# Add the gradient from matching classes.
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
softmax_update = 1.0 - target_mask.view(-1).float()
......
......@@ -8,15 +8,16 @@ from megatron.core.parallel_state import (
get_tensor_model_parallel_src_rank,
)
_MAX_DATA_DIM = 5
def _check_data_types(keys, data, target_dtype):
"""Check that all the keys have the same target data type."""
for key in keys:
assert data[key].dtype == target_dtype, (
'{} has data type {} which '
'is different than {}'.format(key, data[key].dtype, target_dtype)
)
def _build_key_size_numel_dictionaries(keys, data):
......@@ -36,8 +37,9 @@ def _build_key_size_numel_dictionaries(keys, data):
# Move to GPU and broadcast.
sizes_cuda = torch.cuda.LongTensor(sizes)
torch.distributed.broadcast(
sizes_cuda, get_tensor_model_parallel_src_rank(), group=get_tensor_model_parallel_group()
)
# Move back to cpu and unpack.
sizes_cpu = sizes_cuda.cpu()
......@@ -74,24 +76,21 @@ def broadcast_data(keys, data, datatype):
"""
# Build (key, size) and (key, number of elements) dictionaries along
# with the total number of elements on all ranks.
key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys, data)
# Pack on rank zero.
if get_tensor_model_parallel_rank() == 0:
# Check that all keys have the same data type.
_check_data_types(keys, data, datatype)
# Flatten the data associated with the keys
flatten_data = torch.cat([data[key].contiguous().view(-1) for key in keys], dim=0).cuda()
else:
flatten_data = torch.empty(total_numel, device=torch.cuda.current_device(), dtype=datatype)
# Broadcast
torch.distributed.broadcast(
flatten_data, get_tensor_model_parallel_src_rank(), group=get_tensor_model_parallel_group()
)
# Unpack
output = {}
......
......@@ -5,37 +5,33 @@
import math
import os
import warnings
from typing import Callable, Optional
import torch
import torch.nn.functional as F
import torch.nn.init as init
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.nn.parameter import Parameter
from megatron.core.model_parallel_config import ModelParallelConfig
from megatron.core.parallel_state import (
get_global_memory_buffer,
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from .mappings import (
copy_to_tensor_model_parallel_region,
gather_from_sequence_parallel_region,
gather_from_tensor_model_parallel_region,
reduce_from_tensor_model_parallel_region,
reduce_scatter_to_sequence_parallel_region,
scatter_to_tensor_model_parallel_region,
)
from .random import get_cuda_rng_tracker
from .utils import VocabUtility, divide, split_tensor_along_last_dim
_grad_accum_fusion_available = True
try:
......@@ -43,14 +39,17 @@ try:
except ImportError:
_grad_accum_fusion_available = False
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
'tensor_model_parallel': False,
'partition_dim': -1,
'partition_stride': 1,
}
def param_is_not_tensor_parallel_duplicate(param):
return (hasattr(param, 'tensor_model_parallel') and param.tensor_model_parallel) or (
get_tensor_model_parallel_rank() == 0
)
def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
......@@ -67,6 +66,7 @@ def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor):
def maybe_set(attribute, value):
if not hasattr(tensor, attribute):
setattr(tensor, attribute, value)
for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
maybe_set(attribute, _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS[attribute])
......@@ -74,51 +74,52 @@ def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor):
def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
def maybe_copy(attribute):
if hasattr(source_tensor, attribute):
setattr(destination_tensor, attribute, getattr(source_tensor, attribute))
for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
maybe_copy(attribute)
def _initialize_affine_weight_gpu(weight, init_method, partition_dim, stride=1):
"""Initialize affine weight for model parallel on GPU."""
set_tensor_model_parallel_attributes(
tensor=weight, is_parallel=True, dim=partition_dim, stride=stride
)
with get_cuda_rng_tracker().fork():
init_method(weight)
def _initialize_affine_weight_cpu(
weight,
output_size,
input_size,
per_partition_size,
partition_dim,
init_method,
stride=1,
return_master_weight=False,
*,
params_dtype=torch.float32,
):
"""Initialize affine weight for model parallel.
Build the master weight on all processes and scatter
the relevant chunk."""
set_tensor_model_parallel_attributes(
tensor=weight, is_parallel=True, dim=partition_dim, stride=stride
)
# Initialize master weight
master_weight = torch.empty(output_size, input_size, dtype=torch.float, requires_grad=False)
init_method(master_weight)
master_weight = master_weight.to(dtype=params_dtype)
# Split and copy
per_partition_per_stride_size = divide(per_partition_size, stride)
weight_list = torch.split(master_weight, per_partition_per_stride_size, dim=partition_dim)
rank = get_tensor_model_parallel_rank()
world_size = get_tensor_model_parallel_world_size()
my_weight_list = weight_list[rank::world_size]
......@@ -140,17 +141,17 @@ class VocabParallelEmbedding(torch.nn.Module):
embedding_dim: size of hidden state.
Keyword Arguments:
init_method: method to initialize weights.
config: A megatron.core.ModelParallelConfig object
"""
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
*,
init_method: Callable,
config: ModelParallelConfig,
):
super(VocabParallelEmbedding, self).__init__()
# Keep the input dimensions.
self.num_embeddings = num_embeddings
......@@ -158,52 +159,68 @@ class VocabParallelEmbedding(torch.nn.Module):
# Set the defaults for compatibility.
self.padding_idx = None
self.max_norm = None
self.norm_type = 2.0
self.scale_grad_by_freq = False
self.sparse = False
self._weight = None
self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
# Divide the weight matrix along the vocabulary dimension.
(
self.vocab_start_index,
self.vocab_end_index,
) = VocabUtility.vocab_range_from_global_vocab_size(
self.num_embeddings, get_tensor_model_parallel_rank(), self.tensor_model_parallel_size
)
self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index
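        # Example (illustrative): with num_embeddings=50176 and a
        # tensor-model-parallel world size of 8, each rank owns 6272 embedding
        # rows; rank 1 covers vocab ids 6272 through 12543.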
# Allocate weights and initialize.
if config.use_cpu_initialization:
self.weight = Parameter(
torch.empty(
self.num_embeddings_per_partition, self.embedding_dim, dtype=config.params_dtype
)
)
if config.perform_initialization:
_initialize_affine_weight_cpu(
self.weight,
self.num_embeddings,
self.embedding_dim,
self.num_embeddings_per_partition,
0,
init_method,
params_dtype=config.params_dtype,
)
else:
self.weight = Parameter(
torch.empty(
self.num_embeddings_per_partition,
self.embedding_dim,
device=torch.cuda.current_device(),
dtype=config.params_dtype,
)
)
if config.perform_initialization:
_initialize_affine_weight_gpu(self.weight, init_method, partition_dim=0, stride=1)
def forward(self, input_):
if self.tensor_model_parallel_size > 1:
# Build the mask.
input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
# Mask the input.
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
else:
masked_input = input_
# Get the embeddings.
output_parallel = F.embedding(
masked_input,
self.weight,
self.padding_idx,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.sparse,
)
# Mask the output embedding.
if self.tensor_model_parallel_size > 1:
output_parallel[input_mask, :] = 0.0
......@@ -212,13 +229,97 @@ class VocabParallelEmbedding(torch.nn.Module):
return output
class LinearWithFrozenWeight(torch.autograd.Function):
"""Linear operator that does not calculate gradient for weight.
This op and LinearWithGradAccumulationAndAsyncCommunication perform
mathematically-identical forward and DGRAD.
Conceptually this op is the same as torch.nn.functional.linear with
weight.requires_grad==False, but in experiments they are not identical
mathematically. """
@staticmethod
@custom_fwd
def forward(
ctx, input, weight, bias,
):
ctx.save_for_backward(weight)
output = torch.matmul(input, weight.t())
if bias is not None:
output = output + bias
return output
@staticmethod
@custom_bwd
def backward(ctx, grad_output):
(weight,) = ctx.saved_tensors
grad_input = grad_output.matmul(weight)
return grad_input, None, None
def linear_with_frozen_weight(
input: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor],
gradient_accumulation_fusion: bool,
async_grad_allreduce: bool,
sequence_parallel: bool,
) -> torch.Tensor:
"""Linear layer execution with weight.requires_grad == False.
This function handles linear layers with weight frozen (untrainable).
In the forward, it only saves weight and does not save input activations.
In the backward, it does not perform weight gradient calculation, or
weight gradient allreduce.
Arguments:
input (torch.Tensor required): input like torch.nn.functional.linear
weight (torch.Tensor required): weight like torch.nn.functional.linear
bias (torch.Tensor optional): bias like torch.nn.functional.linear
gradient_accumulation_fusion (bool required): dummy argument, used to
keep the API unified between all forward implementation functions.
async_grad_allreduce (bool required): dummy argument, used to
keep the API unified between all forward implementation functions.
sequence_parallel (bool required): Indicates that sequence
parallelism is used and thus in the forward pass the input is
all gathered, and the backward pass the input gradients are
reduce scattered.
"""
if sequence_parallel:
input = gather_from_sequence_parallel_region(input, tensor_parallel_output_grad=True)
else:
input = input
args = [
input,
weight,
bias,
]
return LinearWithFrozenWeight.apply(*args)
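# Illustrative usage sketch (hypothetical tensors, not part of this module):
# applying a frozen projection, e.g. during fine-tuning with the weight kept
# untrainable.
#
#     weight.requires_grad = False
#     output = linear_with_frozen_weight(
#         hidden_states, weight, None,
#         gradient_accumulation_fusion=False,
#         async_grad_allreduce=False,
#         sequence_parallel=False,
#     )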
class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
"""See linear_with_grad_accumulation_and_async_allreduce"""
@staticmethod
@custom_fwd
def forward(
ctx,
input,
weight,
bias,
gradient_accumulation_fusion,
async_grad_allreduce,
sequence_parallel,
):
ctx.save_for_backward(input, weight)
ctx.use_bias = bias is not None
ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
......@@ -230,12 +331,10 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
dim_size = list(input.size())
dim_size[0] = dim_size[0] * world_size
all_gather_buffer = get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
torch.distributed._all_gather_base(
all_gather_buffer, input, group=get_tensor_model_parallel_group()
)
total_input = all_gather_buffer
else:
total_input = input
......@@ -256,12 +355,10 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
dim_size = list(input.size())
dim_size[0] = dim_size[0] * world_size
all_gather_buffer = get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
handle = torch.distributed._all_gather_base(
all_gather_buffer, input, group=get_tensor_model_parallel_group(), async_op=True
)
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# gather is scheduled before the input gradient computation
......@@ -273,43 +370,49 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
if ctx.sequence_parallel:
handle.wait()
# Doing gather + slicing during the NeMo forward pass can make this tensor
# not be contiguous. PyTorch only checks if the tensor is contiguous, and only
# clones it if it's not contiguous:
# https://github.com/pytorch/pytorch/blob/c47cf9bc7f9e02f649ab4ed53fe4d35732c92ab6/torch/_refs/__init__.py#L2761
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1],
grad_output.shape[2])
total_input = total_input.view(total_input.shape[0] * total_input.shape[1],
total_input.shape[2])
grad_output = grad_output.view(
grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]
)
total_input = total_input.view(
total_input.shape[0] * total_input.shape[1], total_input.shape[2]
)
if ctx.async_grad_allreduce:
# Asynchronous all-reduce
handle = torch.distributed.all_reduce(
grad_input, group=get_tensor_model_parallel_group(), async_op=True)
grad_input, group=get_tensor_model_parallel_group(), async_op=True
)
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# all-reduce is scheduled before the weight gradient computation
if ctx.sequence_parallel:
assert not ctx.async_grad_allreduce
dim_size = list(input.size())
sub_grad_input = torch.empty(dim_size, dtype=input.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
sub_grad_input = torch.empty(
dim_size, dtype=input.dtype, device=torch.cuda.current_device(), requires_grad=False
)
# reduce_scatter
handle = torch.distributed._reduce_scatter_base(sub_grad_input, grad_input,
group=get_tensor_model_parallel_group(),
async_op=True)
handle = torch.distributed._reduce_scatter_base(
sub_grad_input, grad_input, group=get_tensor_model_parallel_group(), async_op=True
)
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# reduce scatter is scheduled before the weight gradient computation
if ctx.gradient_accumulation_fusion:
if weight.main_grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, weight.main_grad)
elif weight.main_grad.dtype == torch.float16:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, weight.main_grad)
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(
total_input, grad_output, weight.main_grad
)
elif weight.main_grad.dtype in (torch.float16, torch.bfloat16):
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(
total_input, grad_output, weight.main_grad
)
else:
raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
grad_weight = None
......@@ -326,13 +429,14 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
return grad_input, grad_weight, grad_bias, None, None, None
def linear_with_grad_accumulation_and_async_allreduce(
input: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor],
gradient_accumulation_fusion: bool,
async_grad_allreduce: bool,
sequence_parallel_enabled: bool,
sequence_parallel: bool,
) -> torch.Tensor:
"""Linear layer execution with asynchronous communication and
gradient accumulation fusion in backprop.
......@@ -378,10 +482,10 @@ def linear_with_grad_accumulation_and_async_allreduce(
async_grad_allreduce (bool required): Do the allreduce of input
gradients asynchronously with the computation of weight
gradients. If sequence_parallel_enabled is True, this must be
gradients. If sequence_parallel is True, this must be
False, as no all reduce is performed.
sequence_parallel_enabled (bool required): Indicates that sequence
sequence_parallel (bool required): Indicates that sequence
parallelism is used; in the forward pass the input is
all-gathered, and in the backward pass the input gradients are
reduce-scattered.
......@@ -392,29 +496,33 @@ def linear_with_grad_accumulation_and_async_allreduce(
bias,
gradient_accumulation_fusion,
async_grad_allreduce,
sequence_parallel_enabled,
sequence_parallel,
]
if not linear_with_grad_accumulation_and_async_allreduce.warned:
if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
if sequence_parallel_enabled:
if sequence_parallel:
warnings.warn(
"When using sequence parallelism it is recommended to set the "
"environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
"maximum speedup")
"maximum speedup"
)
linear_with_grad_accumulation_and_async_allreduce.warned = True
if async_grad_allreduce:
warnings.warn(
"When using async grad allreduce it is recommended to set the "
"environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
"maximum speedup")
"maximum speedup"
)
linear_with_grad_accumulation_and_async_allreduce.warned = True
return LinearWithGradAccumulationAndAsyncCommunication.apply(*args)
linear_with_grad_accumulation_and_async_allreduce.warned = False
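# Illustrative sketch (added for documentation, not part of the original
# module): a hypothetical call site, assuming the model-parallel state is
# initialized and CUDA_DEVICE_MAX_CONNECTIONS=1 was exported before CUDA
# start-up so the asynchronous all-reduce can overlap the weight-gradient
# matmul, as the warnings above explain.
def _example_async_linear(activations, weight):
    # activations: [sequence, batch, hidden]. async_grad_allreduce must be
    # False whenever sequence_parallel is True, since no all-reduce is issued.
    return linear_with_grad_accumulation_and_async_allreduce(
        input=activations,
        weight=weight,
        bias=None,
        gradient_accumulation_fusion=False,
        async_grad_allreduce=True,
        sequence_parallel=False,
    )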
class ColumnParallelLinear(torch.nn.Module):
"""Linear layer with column parallelism.
......@@ -436,28 +544,34 @@ class ColumnParallelLinear(torch.nn.Module):
keep_master_weight_for_test: This was added for testing and should be
set to False. It returns the master weights
used for initialization.
skip_bias_add: This was added to enable performance optimizations where bias
can be fused with other elementwise operations. We skip
adding bias but instead return it.
async_tensor_model_parallel_allreduce:
params_dtype:
use_cpu_initialization:
gradient_accumulation_fusion:
sequence_parallel_enabled:
skip_bias_add: If True, do not add the bias term, instead
return it to be added by the caller. This
enables performance optimizations where bias can
be fused with other elementwise operations.
skip_weight_param_allocation: If True, weight parameter is not allocated and must be passed
as a keyword argument `weight` during the forward pass. Note
that this does not affect bias, which will be allocated if
bias is True. Defaults to False.
config: ModelParallelConfig object
"""
def __init__(self, input_size, output_size, *,
bias=True, gather_output=True,
init_method=init.xavier_normal_, stride=1,
keep_master_weight_for_test=False,
skip_bias_add=False,
async_tensor_model_parallel_allreduce=True,
params_dtype=torch.float32,
use_cpu_initialization=False,
perform_initialization=True,
gradient_accumulation_fusion=False,
sequence_parallel_enabled: bool = False,
):
def __init__(
self,
input_size,
output_size,
*,
config: ModelParallelConfig,
init_method: Callable,
bias=True,
gather_output=False,
stride=1,
keep_master_weight_for_test=False,
skip_bias_add=False,
skip_weight_param_allocation: bool = False,
):
super(ColumnParallelLinear, self).__init__()
# Keep input parameters
......@@ -468,105 +582,151 @@ class ColumnParallelLinear(torch.nn.Module):
world_size = get_tensor_model_parallel_world_size()
self.output_size_per_partition = divide(output_size, world_size)
self.skip_bias_add = skip_bias_add
self.config = config
# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose.
# Initialize weight.
if use_cpu_initialization:
self.weight = Parameter(torch.empty(self.output_size_per_partition,
self.input_size,
dtype=params_dtype))
if perform_initialization:
self.master_weight = _initialize_affine_weight_cpu(
self.weight, self.output_size, self.input_size,
self.output_size_per_partition, 0, init_method,
stride=stride, return_master_weight=keep_master_weight_for_test)
if not skip_weight_param_allocation:
if config.use_cpu_initialization:
self.weight = Parameter(
torch.empty(
self.output_size_per_partition, self.input_size, dtype=config.params_dtype
)
)
if config.perform_initialization:
self.master_weight = _initialize_affine_weight_cpu(
self.weight,
self.output_size,
self.input_size,
self.output_size_per_partition,
0,
init_method,
stride=stride,
return_master_weight=keep_master_weight_for_test,
)
else:
self.weight = Parameter(
torch.empty(
self.output_size_per_partition,
self.input_size,
device=torch.cuda.current_device(),
dtype=config.params_dtype,
)
)
if config.perform_initialization:
_initialize_affine_weight_gpu(
self.weight, init_method, partition_dim=0, stride=stride
)
else:
self.weight = Parameter(torch.empty(
self.output_size_per_partition, self.input_size,
device=torch.cuda.current_device(), dtype=params_dtype))
if perform_initialization:
_initialize_affine_weight_gpu(self.weight, init_method,
partition_dim=0, stride=stride)
self.weight = None
if bias:
if use_cpu_initialization:
self.bias = Parameter(torch.empty(
self.output_size_per_partition, dtype=params_dtype))
if config.use_cpu_initialization:
self.bias = Parameter(
torch.empty(self.output_size_per_partition, dtype=config.params_dtype)
)
else:
self.bias = Parameter(torch.empty(
self.output_size_per_partition,
device=torch.cuda.current_device(),
dtype=params_dtype))
self.bias = Parameter(
torch.empty(
self.output_size_per_partition,
device=torch.cuda.current_device(),
dtype=config.params_dtype,
)
)
set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
if config.perform_initialization:
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
else:
self.register_parameter('bias', None)
self.async_tensor_model_parallel_allreduce = (
async_tensor_model_parallel_allreduce and
world_size > 1)
if sequence_parallel_enabled:
if world_size <= 1:
warnings.warn(
f"`sequence_parallel_enabled` is set to `True`, but tensor model parallel size is {world_size}. "
f"Disabling sequence parallel."
)
sequence_parallel_enabled = False
self.sequence_parallel_enabled = sequence_parallel_enabled
config.async_tensor_model_parallel_allreduce and world_size > 1
)
if gradient_accumulation_fusion:
if not _grad_accum_fusion_available:
raise RuntimeError(
"ColumnParallelLinear was called with gradient_accumulation_fusion set "
"to True but the custom CUDA extension fused_weight_gradient_mlp_cuda "
"module is not found. To use gradient_accumulation_fusion you must "
"install APEX with --cpp_ext and --cuda_ext. For example: "
"pip install --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext .\" "
"Note that the extension requires CUDA>=11. Otherwise, you must turn off "
"gradient accumulation fusion."
)
self.gradient_accumulation_fusion = gradient_accumulation_fusion
self.sequence_parallel = config.sequence_parallel
if self.sequence_parallel and world_size <= 1:
warnings.warn(
f"`sequence_parallel` is set to `True`, but tensor model parallel size is {world_size}. "
f"Disabling sequence parallel."
)
self.sequence_parallel = False
if config.gradient_accumulation_fusion and not _grad_accum_fusion_available:
raise RuntimeError(
"ColumnParallelLinear was called with gradient_accumulation_fusion set "
"to True but the custom CUDA extension fused_weight_gradient_mlp_cuda "
"module is not found. To use gradient_accumulation_fusion you must "
"install APEX with --cpp_ext and --cuda_ext. For example: "
"pip install --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext .\" "
"Note that the extension requires CUDA>=11. Otherwise, you must turn off "
"gradient accumulation fusion."
)
self.gradient_accumulation_fusion = config.gradient_accumulation_fusion
if self.async_tensor_model_parallel_allreduce and self.sequence_parallel_enabled:
if self.async_tensor_model_parallel_allreduce and self.sequence_parallel:
raise RuntimeError(
"`async_tensor_model_parallel_allreduce` and `sequence_parallel_enabled` "
"`async_tensor_model_parallel_allreduce` and `sequence_parallel` "
"cannot be enabled at the same time."
)
self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
def forward(self, input_):
def forward(self, input_: torch.Tensor, weight: Optional[torch.Tensor] = None):
"""Forward of ColumnParallelLinear
Args:
input_: 3D tensor whose dimensions are ordered [sequence, batch, hidden]
weight (optional): weight tensor to use, compulsory when
skip_weight_param_allocation is True.
Returns:
- output
- bias
"""
if weight is None:
if self.weight is None:
raise RuntimeError(
"weight was not supplied to ColumnParallelLinear forward pass "
"and skip_weight_param_allocation is True."
)
weight = self.weight
else:
# Check the weight passed in is the correct shape
expected_shape = (self.output_size_per_partition, self.input_size)
if weight.shape != expected_shape:
raise RuntimeError(
f"supplied weight's shape is {tuple(weight.shape)}, "
f"not {expected_shape} as expected"
)
bias = self.bias if not self.skip_bias_add else None
if self.async_tensor_model_parallel_allreduce or \
self.sequence_parallel_enabled:
if self.async_tensor_model_parallel_allreduce or self.sequence_parallel:
input_parallel = input_
else:
input_parallel = copy_to_tensor_model_parallel_region(input_)
# Matrix multiply.
output_parallel = linear_with_grad_accumulation_and_async_allreduce(
if not weight.requires_grad:
self._forward_impl = linear_with_frozen_weight
else:
self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
output_parallel = self._forward_impl(
input=input_parallel,
weight=self.weight,
weight=weight,
bias=bias,
gradient_accumulation_fusion=self.gradient_accumulation_fusion,
async_grad_allreduce=self.async_tensor_model_parallel_allreduce,
sequence_parallel_enabled=self.sequence_parallel_enabled,
sequence_parallel=self.sequence_parallel,
)
if self.gather_output:
# All-gather across the partitions.
assert not self.sequence_parallel_enabled
assert not self.sequence_parallel
output = gather_from_tensor_model_parallel_region(output_parallel)
else:
output = output_parallel
......@@ -601,27 +761,27 @@ class RowParallelLinear(torch.nn.Module):
keep_master_weight_for_test: This was added for testing and should be
set to False. It returns the master weights
used for initialization.
skip_bias_add: This was added to enable performance optimization where bias
can be fused with other elementwise operations. We skip
adding bias but instead return it.
params_dtype:
use_cpu_initialization:
perform_initialization:
gradient_accumulation_fusion:
sequence_parallel_enabled:
skip_bias_add: If True, do not add the bias term, instead
return it to be added by the caller. This
enables performance optimizations where bias can
be fused with other elementwise operations.
config: ModelParallelConfig object
"""
def __init__(self, input_size, output_size, *,
bias=True, input_is_parallel=False,
init_method=init.xavier_normal_, stride=1,
keep_master_weight_for_test=False,
skip_bias_add=False,
params_dtype=torch.float32,
use_cpu_initialization=False,
perform_initialization=True,
gradient_accumulation_fusion=False,
sequence_parallel_enabled: bool = False,
):
def __init__(
self,
input_size: int,
output_size: int,
*,
config: ModelParallelConfig,
init_method: Callable,
bias: bool = True,
input_is_parallel: bool = False,
stride: int = 1,
keep_master_weight_for_test: bool = False,
skip_bias_add: bool = False,
):
super(RowParallelLinear, self).__init__()
# Keep input parameters
......@@ -632,49 +792,68 @@ class RowParallelLinear(torch.nn.Module):
world_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, world_size)
self.skip_bias_add = skip_bias_add
self.gradient_accumulation_fusion = gradient_accumulation_fusion
self.sequence_parallel_enabled = sequence_parallel_enabled
if self.sequence_parallel_enabled and not self.input_is_parallel:
raise RuntimeError("To enable `sequence_parallel_enabled`, `input_is_parallel` must be `True`")
self.config = config
self.gradient_accumulation_fusion = config.gradient_accumulation_fusion
self.sequence_parallel = config.sequence_parallel
if self.sequence_parallel and not self.input_is_parallel:
raise RuntimeError("To enable `sequence_parallel`, `input_is_parallel` must be `True`")
# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose.
# Initialize weight.
if use_cpu_initialization:
self.weight = Parameter(torch.empty(self.output_size,
self.input_size_per_partition,
dtype=params_dtype))
if perform_initialization:
if config.use_cpu_initialization:
self.weight = Parameter(
torch.empty(
self.output_size, self.input_size_per_partition, dtype=config.params_dtype
)
)
if config.perform_initialization:
self.master_weight = _initialize_affine_weight_cpu(
self.weight, self.output_size, self.input_size,
self.input_size_per_partition, 1, init_method,
stride=stride, return_master_weight=keep_master_weight_for_test,
params_dtype=params_dtype)
self.weight,
self.output_size,
self.input_size,
self.input_size_per_partition,
1,
init_method,
stride=stride,
return_master_weight=keep_master_weight_for_test,
params_dtype=config.params_dtype,
)
else:
self.weight = Parameter(torch.empty(
self.output_size, self.input_size_per_partition,
device=torch.cuda.current_device(), dtype=params_dtype))
if perform_initialization:
_initialize_affine_weight_gpu(self.weight, init_method,
partition_dim=1, stride=stride)
self.weight = Parameter(
torch.empty(
self.output_size,
self.input_size_per_partition,
device=torch.cuda.current_device(),
dtype=config.params_dtype,
)
)
if config.perform_initialization:
_initialize_affine_weight_gpu(
self.weight, init_method, partition_dim=1, stride=stride
)
if bias:
if use_cpu_initialization:
self.bias = Parameter(torch.empty(self.output_size,
dtype=params_dtype))
if config.use_cpu_initialization:
self.bias = Parameter(torch.empty(self.output_size, dtype=config.params_dtype))
else:
self.bias = Parameter(torch.empty(
self.output_size, device=torch.cuda.current_device(),
dtype=params_dtype))
setattr(self.bias, 'sequence_parallel', sequence_parallel_enabled)
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
self.bias = Parameter(
torch.empty(
self.output_size,
device=torch.cuda.current_device(),
dtype=config.params_dtype,
)
)
setattr(self.bias, 'sequence_parallel', self.sequence_parallel)
if config.perform_initialization:
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
else:
self.register_parameter('bias', None)
self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
def forward(self, input_):
"""Forward of RowParallelLinear
......@@ -690,20 +869,24 @@ class RowParallelLinear(torch.nn.Module):
if self.input_is_parallel:
input_parallel = input_
else:
assert not self.sequence_parallel_enabled
assert not self.sequence_parallel
input_parallel = scatter_to_tensor_model_parallel_region(input_)
# Matrix multiply.
output_parallel = linear_with_grad_accumulation_and_async_allreduce(
if not self.weight.requires_grad:
self._forward_impl = linear_with_frozen_weight
else:
self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
output_parallel = self._forward_impl(
input=input_parallel,
weight=self.weight,
bias=None,
gradient_accumulation_fusion=self.gradient_accumulation_fusion,
async_grad_allreduce=False,
sequence_parallel_enabled=False,
sequence_parallel=False,
)
# All-reduce across all the partitions.
if self.sequence_parallel_enabled:
if self.sequence_parallel:
output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
else:
output_ = reduce_from_tensor_model_parallel_region(output_parallel)
......
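# Illustrative sketch (added for documentation, not part of the original
# module): how ColumnParallelLinear and RowParallelLinear are typically paired
# in a tensor-parallel MLP, assuming a populated ModelParallelConfig `config`,
# an `init_method` callable, and an initialized model-parallel state. Keeping
# the column output partitioned (gather_output=False) and feeding it straight
# into the row layer (input_is_parallel=True) means only the row layer's
# output needs an all-reduce (or a reduce-scatter under sequence parallelism).
def _example_parallel_mlp(config, init_method, hidden_size, ffn_hidden_size):
    fc1 = ColumnParallelLinear(
        hidden_size,
        ffn_hidden_size,
        config=config,
        init_method=init_method,
        gather_output=False,
    )
    fc2 = RowParallelLinear(
        ffn_hidden_size,
        hidden_size,
        config=config,
        init_method=init_method,
        input_is_parallel=True,
        skip_bias_add=True,
    )
    return fc1, fc2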
......@@ -3,10 +3,11 @@
import torch
from megatron.core.parallel_state import (
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_tensor_model_parallel_group,
)
from .utils import split_tensor_along_last_dim
......@@ -14,7 +15,7 @@ def _reduce(input_):
"""All-reduce the input tensor across model parallel group."""
# Bypass the function if we are using only 1 GPU.
if get_tensor_model_parallel_world_size()==1:
if get_tensor_model_parallel_world_size() == 1:
return input_
# All-reduce.
......@@ -53,13 +54,14 @@ def _split_along_first_dim(input_):
# Split along first dimension.
dim_size = input_.size()[0]
assert dim_size % world_size == 0, \
"First dimension of the tensor should be divisible by tensor parallel size"
assert (
dim_size % world_size == 0
), "First dimension of the tensor should be divisible by tensor parallel size"
local_dim_size = dim_size // world_size
rank = get_tensor_model_parallel_rank()
dim_offset = rank * local_dim_size
output = input_[dim_offset:dim_offset+local_dim_size].contiguous()
output = input_[dim_offset : dim_offset + local_dim_size].contiguous()
return output
......@@ -97,13 +99,14 @@ def _gather_along_first_dim(input_):
dim_size = list(input_.size())
dim_size[0] = dim_size[0] * world_size
output = torch.empty(dim_size, dtype=input_.dtype,
device=torch.cuda.current_device())
torch.distributed._all_gather_base(output, input_.contiguous(),
group=get_tensor_model_parallel_group())
output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device())
torch.distributed._all_gather_base(
output, input_.contiguous(), group=get_tensor_model_parallel_group()
)
return output
def _reduce_scatter_along_first_dim(input_):
"""Reduce-scatter the input tensor across model parallel group."""
world_size = get_tensor_model_parallel_world_size()
......@@ -112,15 +115,16 @@ def _reduce_scatter_along_first_dim(input_):
return input_
dim_size = list(input_.size())
assert dim_size[0] % world_size == 0, \
"First dimension of the tensor should be divisible by tensor parallel size"
assert (
dim_size[0] % world_size == 0
), "First dimension of the tensor should be divisible by tensor parallel size"
dim_size[0] = dim_size[0] // world_size
output = torch.empty(dim_size, dtype=input_.dtype,
device=torch.cuda.current_device())
torch.distributed._reduce_scatter_base(output, input_.contiguous(),
group=get_tensor_model_parallel_group())
output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device())
torch.distributed._reduce_scatter_base(
output, input_.contiguous(), group=get_tensor_model_parallel_group()
)
return output
......@@ -130,7 +134,7 @@ class _CopyToModelParallelRegion(torch.autograd.Function):
@staticmethod
def symbolic(graph, input_):
return input_
@staticmethod
def forward(ctx, input_):
return input_
......@@ -146,7 +150,7 @@ class _ReduceFromModelParallelRegion(torch.autograd.Function):
@staticmethod
def symbolic(graph, input_):
return _reduce(input_)
@staticmethod
def forward(ctx, input_):
return _reduce(input_)
......@@ -178,7 +182,7 @@ class _GatherFromModelParallelRegion(torch.autograd.Function):
@staticmethod
def symbolic(graph, input_):
return _gather_along_last_dim(input_)
@staticmethod
def forward(ctx, input_):
return _gather_along_last_dim(input_)
......@@ -205,12 +209,12 @@ class _ScatterToSequenceParallelRegion(torch.autograd.Function):
class _GatherFromSequenceParallelRegion(torch.autograd.Function):
"""Gather the input from sequence parallel region and concatinate."""
"""Gather the input from sequence parallel region and concatinate."""
@staticmethod
def symbolic(graph, input_, tensor_parallel_output_grad=True):
return _gather_along_first_dim(input_)
@staticmethod
def forward(ctx, input_, tensor_parallel_output_grad=True):
ctx.tensor_parallel_output_grad = tensor_parallel_output_grad
......@@ -221,8 +225,8 @@ class _GatherFromSequenceParallelRegion(torch.autograd.Function):
tensor_parallel_output_grad = ctx.tensor_parallel_output_grad
# If the computation graph after the gather operation is
# in the tensor parallel mode, output gradients need to be
# reduce-scattered, whereas if the computation is duplicated,
# output gradients need to be scattered.
if tensor_parallel_output_grad:
return _reduce_scatter_along_first_dim(grad_output), None
......@@ -236,7 +240,7 @@ class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function):
@staticmethod
def symbolic(graph, input_):
return _reduce_scatter_along_first_dim(input_)
@staticmethod
def forward(ctx, input_):
return _reduce_scatter_along_first_dim(input_)
......@@ -250,6 +254,7 @@ class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function):
# Helper functions.
# -----------------
def copy_to_tensor_model_parallel_region(input_):
return _CopyToModelParallelRegion.apply(input_)
......@@ -276,4 +281,3 @@ def gather_from_sequence_parallel_region(input_, tensor_parallel_output_grad=Tru
def reduce_scatter_to_sequence_parallel_region(input_):
return _ReduceScatterToSequenceParallelRegion.apply(input_)
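# Illustrative sketch (added for documentation, not part of the original
# module): a simplified forward-time round trip for sequence parallelism,
# assuming an initialized tensor-model-parallel group. Activations sharded
# along the sequence (first) dimension are all-gathered before the
# tensor-parallel computation and reduce-scattered afterwards; the autograd
# functions above make the two collectives gradients of each other.
def _example_sequence_parallel_roundtrip(sharded_activations, tensor_parallel_fn):
    gathered = gather_from_sequence_parallel_region(sharded_activations)
    partial_output = tensor_parallel_fn(gathered)
    return reduce_scatter_to_sequence_parallel_region(partial_output)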
......@@ -7,7 +7,8 @@ import contextlib
import torch
from torch import _C
from torch.cuda import _lazy_call, device as device_ctx_manager
from torch.cuda import _lazy_call
from torch.cuda import device as device_ctx_manager
from torch.utils.checkpoint import detach_variable
from megatron.core.parallel_state import (
......@@ -16,14 +17,10 @@ from megatron.core.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from .utils import (
split_tensor_into_1d_equal_chunks,
gather_split_1d_tensor,
)
from megatron.core.utils import safely_set_viewless_tensor_data
from .utils import gather_split_1d_tensor, split_tensor_into_1d_equal_chunks
# Default name for the model parallel rng tracker.
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
......@@ -42,6 +39,7 @@ def _set_cuda_rng_state(new_state, device=-1):
def cb():
with device_ctx_manager(device):
_C._cuda_setRNGState(new_state)
else:
# newer PyTorch
if device == -1:
......@@ -61,7 +59,6 @@ def _set_cuda_rng_state(new_state, device=-1):
_lazy_call(cb)
class CudaRNGStatesTracker:
"""Tracker for the cuda RNG states.
......@@ -169,8 +166,7 @@ def model_parallel_cuda_manual_seed(seed):
# Set the default state.
torch.cuda.manual_seed(data_parallel_seed)
# and model parallel state.
_CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME,
tensor_model_parallel_seed)
_CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, tensor_model_parallel_seed)
class CheckpointFunction(torch.autograd.Function):
......@@ -180,11 +176,11 @@ class CheckpointFunction(torch.autograd.Function):
2) the states in the model parallel tracker are also properly
tracked/set/reset.
"""
@staticmethod
def forward(ctx, run_function, distribute_saved_activations, *args):
ctx.run_function = run_function
ctx.distribute_saved_activations \
= distribute_saved_activations
ctx.distribute_saved_activations = distribute_saved_activations
# Copy the rng states.
ctx.fwd_cpu_rng_state = torch.get_rng_state()
......@@ -199,8 +195,8 @@ class CheckpointFunction(torch.autograd.Function):
if distribute_saved_activations:
ctx.input_0_shape = args[0].data.shape
safely_set_viewless_tensor_data(
args[0],
split_tensor_into_1d_equal_chunks(args[0].data, new_buffer=True))
args[0], split_tensor_into_1d_equal_chunks(args[0].data, new_buffer=True)
)
# Store everything.
ctx.save_for_backward(*args)
......@@ -210,13 +206,15 @@ class CheckpointFunction(torch.autograd.Function):
@staticmethod
def backward(ctx, *args):
if not torch.autograd._is_checkpoint_valid():
raise RuntimeError("Checkpointing is not compatible with .grad(), "
"please use .backward() if possible")
raise RuntimeError(
"Checkpointing is not compatible with .grad(), "
"please use .backward() if possible"
)
inputs = ctx.saved_tensors
if ctx.distribute_saved_activations:
safely_set_viewless_tensor_data(
inputs[0],
gather_split_1d_tensor(inputs[0].data).view(ctx.input_0_shape))
inputs[0], gather_split_1d_tensor(inputs[0].data).view(ctx.input_0_shape)
)
# Store the current states.
bwd_cpu_rng_state = torch.get_rng_state()
......@@ -241,13 +239,11 @@ class CheckpointFunction(torch.autograd.Function):
if isinstance(outputs, torch.Tensor):
outputs = (outputs,)
torch.autograd.backward(outputs, args)
grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp
for inp in detached_inputs)
grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs)
return (None, None) + grads
def checkpoint(function, distribute_saved_activations, *args):
"""Checkpoint a model or part of the model.
This has been directly copied from torch.utils.checkpoint."""
return CheckpointFunction.apply(function,
distribute_saved_activations, *args)
return CheckpointFunction.apply(function, distribute_saved_activations, *args)
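# Illustrative sketch (added for documentation, not part of the original
# module): a hypothetical use of the checkpoint helper. The wrapped function
# is re-executed during the backward pass; passing True for
# distribute_saved_activations additionally shards the saved first input
# across tensor-parallel ranks to save memory.
def _example_checkpointed_block(block, hidden_states, attention_mask):
    def custom_forward(hidden_states, attention_mask):
        return block(hidden_states, attention_mask)

    return checkpoint(custom_forward, False, hidden_states, attention_mask)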
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import torch
from typing import List, Sequence
from megatron.core.utils import divide
import torch
from megatron.core import parallel_state
from megatron.core.utils import divide
def split_tensor_along_last_dim(
tensor: torch.Tensor,
num_partitions: int,
contiguous_split_chunks: bool = False,
tensor: torch.Tensor, num_partitions: int, contiguous_split_chunks: bool = False,
) -> List[torch.Tensor]:
""" Split a tensor along its last dimension.
......@@ -33,6 +33,7 @@ def split_tensor_along_last_dim(
return tensor_list
def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
""" Break a tensor into equal 1D chunks across tensor parallel ranks.
......@@ -47,14 +48,16 @@ def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
Default is False
"""
partition_size = torch.numel(tensor) // \
parallel_state.get_tensor_model_parallel_world_size()
partition_size = torch.numel(tensor) // parallel_state.get_tensor_model_parallel_world_size()
start_index = partition_size * parallel_state.get_tensor_model_parallel_rank()
end_index = start_index + partition_size
if new_buffer:
data = torch.empty(partition_size, dtype=tensor.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
data = torch.empty(
partition_size,
dtype=tensor.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
data.copy_(tensor.view(-1)[start_index:end_index])
else:
data = tensor.view(-1)[start_index:end_index]
......@@ -70,18 +73,18 @@ def gather_split_1d_tensor(tensor):
Arguments:
tensor: A Tensor or view of this rank's portion of the data.
"""
numel_gathered = torch.numel(tensor) * \
parallel_state.get_tensor_model_parallel_world_size()
gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
numel_gathered = torch.numel(tensor) * parallel_state.get_tensor_model_parallel_world_size()
gathered = torch.empty(
numel_gathered, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False
)
# TODO: This API is experimental in pytorch (as of Feb 2022) and
# this might break in future pytorch releases. We chose this API
# as opposed to torch.distributed.all_gather for efficiency reasons.
# This API calls NCCL all-gather directly, whereas the alternative performs
# internal copies and can potentially cause a slowdown.
torch.distributed._all_gather_base(gathered, tensor,
group=parallel_state.get_tensor_model_parallel_group())
torch.distributed._all_gather_base(
gathered, tensor, group=parallel_state.get_tensor_model_parallel_group()
)
return gathered
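# Illustrative sketch (added for documentation, not part of the original
# module): the split/gather round trip, assuming the tensor-parallel group is
# initialized, every rank participates, and the tensor's element count is
# divisible by the tensor-parallel world size.
def _example_split_gather_roundtrip(tensor):
    chunk = split_tensor_into_1d_equal_chunks(tensor, new_buffer=True)
    rebuilt = gather_split_1d_tensor(chunk)  # flattened, full-size tensor
    return rebuilt.view_as(tensor)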
......@@ -101,7 +104,9 @@ class VocabUtility:
return index_f, index_l
@staticmethod
def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int, world_size: int) -> Sequence[int]:
def vocab_range_from_global_vocab_size(
global_vocab_size: int, rank: int, world_size: int
) -> Sequence[int]:
per_partition_vocab_size = divide(global_vocab_size, world_size)
return VocabUtility.vocab_range_from_per_partition_vocab_size(
per_partition_vocab_size, rank, world_size
......
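# Illustrative sketch (added for documentation, not part of the original
# module): a small worked example of the vocabulary partitioning arithmetic,
# assuming the standard contiguous split. A (padded) vocabulary of 50304
# tokens over 8 tensor-parallel ranks gives 6288 ids per rank, so rank 2
# owns the half-open range [12576, 18864).
def _example_vocab_partition():
    first, last = VocabUtility.vocab_range_from_global_vocab_size(
        50304, rank=2, world_size=8
    )
    assert (first, last) == (12576, 18864)
    return first, last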
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from .transformer_config import TransformerConfig
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from abc import ABC, abstractmethod
import torch
from megatron.core import parallel_state, tensor_parallel
from megatron.core.models.common.rotary_pos_embedding import apply_rotary_pos_emb
from megatron.core.transformer.custom_layers.transformer_engine import (
TEDotProductAttention,
TELayerNormColumnParallelLinear,
TERowParallelLinear,
)
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import divide
from .enums import AttnMaskType
from .transformer_config import TransformerConfig
class Attention(MegatronModule, ABC):
"""Attention layer abstract class.
This layer only contains common modules required for the "self attn" and
"cross attn" specializations.
"""
def __init__(
self, config: TransformerConfig, layer_number: int = 1, attn_mask_type=AttnMaskType.padding,
):
super().__init__(config=config)
self.config = config
self.layer_number = layer_number
self.attn_mask_type = attn_mask_type
# For normal attention without groups, num_query_groups == num_attention_heads,
# so these two will be the same
self.query_projection_size = self.config.kv_channels * self.config.num_attention_heads
self.kv_projection_size = self.config.kv_channels * self.config.num_query_groups
# Per attention head and per partition values.
world_size = parallel_state.get_tensor_model_parallel_world_size()
self.hidden_size_per_attention_head = divide(
self.query_projection_size, self.config.num_attention_heads
)
self.num_attention_heads_per_partition = divide(self.config.num_attention_heads, world_size)
self.num_query_groups_per_partition = divide(self.config.num_query_groups, world_size)
self.dot_product_attention = TEDotProductAttention(
config=self.config, layer_number=self.layer_number, attn_mask_type=self.attn_mask_type
)
self.checkpoint_dot_product_attention = self.config.recompute_granularity == 'selective'
# Output.
self.linear_proj = TERowParallelLinear(
self.query_projection_size,
self.config.hidden_size,
config=self.config,
init_method=self.config.output_layer_init_method,
bias=self.config.add_bias_linear,
skip_bias_add=True,
)
def _checkpointed_attention_forward(
self, query, key, value, attention_mask, rotary_pos_emb=None
):
"""Forward method with selective activation checkpointing."""
def custom_forward(*inputs):
query = inputs[0]
key = inputs[1]
value = inputs[2]
attention_mask = inputs[3]
output_ = self.dot_product_attention(query, key, value, attention_mask)
return output_
hidden_states = tensor_parallel.checkpoint(
custom_forward, False, query, key, value, attention_mask, rotary_pos_emb
)
return hidden_states
def _allocate_memory(self, inference_max_sequence_length, batch_size, dtype):
"""Allocate memory to store kv cache during inference."""
return torch.empty(
inference_max_sequence_length,
batch_size,
self.num_query_groups_per_partition,
self.hidden_size_per_attention_head,
dtype=dtype,
device=torch.cuda.current_device(),
)
def _adjust_key_value_for_inference(self, inference_params, key, value, rotary_pos_emb):
"""
Saves the generated key and value tensors to the end of the buffers in inference_params.
Returns the full size keys and values from the provided inference_params, as well as
adjusted rotary_pos_emb.
Returns a tuple: (key, value, rotary_pos_emb)
"""
if inference_params is None:
return key, value, rotary_pos_emb
# =================================================
# Pre-allocate memory for key-values for inference.
# =================================================
is_first_step = False
if self.layer_number not in inference_params.key_value_memory_dict:
inf_max_seq_length = inference_params.max_sequence_length
inf_max_batch_size = inference_params.max_batch_size
inference_key_memory = self._allocate_memory(
inf_max_seq_length, inf_max_batch_size, key.dtype
)
inference_value_memory = self._allocate_memory(
inf_max_seq_length, inf_max_batch_size, value.dtype
)
inference_params.key_value_memory_dict[self.layer_number] = (
inference_key_memory,
inference_value_memory,
)
is_first_step = True
else:
# Get the pre-allocated buffers for this layer
inference_key_memory, inference_value_memory = inference_params.key_value_memory_dict[
self.layer_number
]
batch_start = inference_params.batch_size_offset
batch_end = batch_start + key.size(1)
assert batch_end <= inference_key_memory.size(1)
sequence_start = inference_params.sequence_len_offset
sequence_end = sequence_start + key.size(0)
assert sequence_end <= inference_key_memory.size(0)
# Copy key and values.
inference_key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = key
inference_value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = value
key = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
value = inference_value_memory[:sequence_end, batch_start:batch_end, ...]
# adjust the key rotary positional embedding
if rotary_pos_emb is not None:
q_pos_emb, k_pos_emb = rotary_pos_emb
# need to cross check this condition during inference
# if not set_inference_key_value_memory:
if not is_first_step:
# In inference, we compute one token at a time.
# Select the correct positional embedding
# (only the last token in the sequence)
q_pos_emb = q_pos_emb[sequence_end - 1 : sequence_end]
else:
# In the first forward pass of inference,
# we use the entire provided prefix.
# q_pos_emb here has the rope embeddings of the entire
# prefix + to-be-generated output so
# we slice to just the prefix.
q_pos_emb = q_pos_emb[:sequence_end, :, :, :]
k_pos_emb = k_pos_emb[:sequence_end, :, :, :]
rotary_pos_emb = (q_pos_emb, k_pos_emb)
return key, value, rotary_pos_emb
@abstractmethod
def get_query_key_value_tensors(self, hidden_states, key_value_states):
"""
This method needs to be implemented based on whether the derived class
is "self-attn" or "cross-attn".
"""
def forward(
self,
hidden_states,
attention_mask,
key_value_states=None,
inference_params=None,
rotary_pos_emb=None,
):
# hidden_states: [sq, b, h]
# For self attention we just duplicate the rotary_pos_emb if it isn't already
if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple):
rotary_pos_emb = (rotary_pos_emb,) * 2
# =====================
# Query, Key, and Value
# =====================
# Get the query, key and value tensors based on the type of attention -
# self or cross attn.
query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states)
# ===================================================
# Adjust key, value, and rotary_pos_emb for inference
# ===================================================
key, value, rotary_pos_emb = self._adjust_key_value_for_inference(
inference_params, key, value, rotary_pos_emb
)
# ================================================
# relative positional embedding (rotary embedding)
# ================================================
if rotary_pos_emb is not None:
q_pos_emb, k_pos_emb = rotary_pos_emb
query = apply_rotary_pos_emb(query, q_pos_emb)
key = apply_rotary_pos_emb(key, k_pos_emb)
# TODO, can apply positional embedding to value_layer so it has
# absolute positional embedding.
# otherwise, only relative positional embedding takes effect
# value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb)
# ==================================
# core attention computation
# ==================================
# expand the key_layer and value_layer [sk, b, ng, hn] -> [sk, b, np, hn]
# This is a noop for normal attention where ng == np. When using group query attention this
# repeats the keys and values along their group dimension so that they
# match the number of query heads.
if self.num_attention_heads_per_partition // self.num_query_groups_per_partition > 1:
key = key.repeat_interleave(
self.num_attention_heads_per_partition // self.num_query_groups_per_partition, dim=2
)
value = value.repeat_interleave(
self.num_attention_heads_per_partition // self.num_query_groups_per_partition, dim=2
)
if self.checkpoint_dot_product_attention:
core_attn_out = self._checkpointed_attention_forward(query, key, value, attention_mask)
else:
core_attn_out = self.dot_product_attention(query, key, value, attention_mask)
# =================
# Output. [sq, b, h]
# =================
output, bias = self.linear_proj(core_attn_out)
return output, bias
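# Illustrative sketch (added for documentation, not part of the original
# class): the size of the per-layer inference key/value cache allocated by
# _allocate_memory above, using hypothetical sizes. Each of the key and the
# value buffer has shape [max_seq_len, batch, num_query_groups_per_partition,
# hidden_size_per_attention_head].
def _example_kv_cache_numel(max_seq_len=2048, batch=4, groups=8, head_dim=128):
    # 2048 * 4 * 8 * 128 = 8,388,608 elements per buffer per layer
    return max_seq_len * batch * groups * head_dim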
class SelfAttention(Attention):
"""Self-attention layer class
Self-attention layer takes input with size [s, b, h]
and returns output of the same size.
"""
def __init__(
self, config: TransformerConfig, layer_number: int = 1, attn_mask_type=AttnMaskType.padding
):
super().__init__(config=config, layer_number=layer_number, attn_mask_type=attn_mask_type)
self.linear_qkv = TELayerNormColumnParallelLinear(
self.config.hidden_size,
self.query_projection_size + 2 * self.kv_projection_size,
config=self.config,
init_method=self.config.init_method,
bias=self.config.add_bias_linear,
skip_bias_add=False,
)
def get_query_key_value_tensors(self, hidden_states, key_value_states=None):
"""
Derives `query`, `key` and `value` tensors from `hidden_states`.
"""
# Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn)]
mixed_qkv, _ = self.linear_qkv(hidden_states)
# [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn]
new_tensor_shape = mixed_qkv.size()[:-1] + (
self.num_query_groups_per_partition,
(
(self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2)
* self.hidden_size_per_attention_head
),
)
mixed_qkv = mixed_qkv.view(*new_tensor_shape)
# [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
(query, key, value) = torch.split(
mixed_qkv,
[
(
self.num_attention_heads_per_partition
// self.num_query_groups_per_partition
* self.hidden_size_per_attention_head
),
self.hidden_size_per_attention_head,
self.hidden_size_per_attention_head,
],
dim=3,
)
# [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head)
return query, key, value
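# Illustrative sketch (added for documentation, not part of the original
# class): the grouped-query split arithmetic used above, with hypothetical
# per-rank sizes np=32 attention heads, ng=8 query groups and hn=128 channels
# per head. The fused QKV projection width is ng * (np/ng + 2) * hn, and
# torch.split separates each group into [np/ng * hn, hn, hn] channels for
# query, key and value respectively.
def _example_gqa_split_sizes(np_heads=32, num_groups=8, head_dim=128):
    assert np_heads % num_groups == 0
    fused_width = num_groups * (np_heads // num_groups + 2) * head_dim  # 6144
    per_group_split = [np_heads // num_groups * head_dim, head_dim, head_dim]  # [512, 128, 128]
    return fused_width, per_group_split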
class CrossAttention(Attention):
"""Cross-attention layer class
Cross-attention layer takes input with size [s, b, h] and context with size
[s, b, h] and returns output of the same size.
"""
def __init__(
self, config: TransformerConfig, layer_number: int = 1, attn_mask_type=AttnMaskType.padding
):
super().__init__(config=config, layer_number=layer_number, attn_mask_type=attn_mask_type)
if self.config.num_query_groups != self.config.num_attention_heads:
raise ValueError(
f"Group query attention is not currently supported in cross attention."
)
assert self.query_projection_size == self.kv_projection_size
self.linear_q = TELayerNormColumnParallelLinear(
self.config.hidden_size,
self.query_projection_size,
config=self.config,
init_method=self.config.init_method,
bias=self.config.add_bias_linear,
skip_bias_add=False,
)
self.linear_kv = TELayerNormColumnParallelLinear(
self.config.hidden_size,
2 * self.kv_projection_size,
config=self.config,
init_method=self.config.init_method,
bias=self.config.add_bias_linear,
skip_bias_add=False,
)
def get_query_key_value_tensors(self, hidden_states, key_value_states):
"""
Derives `query` tensor from `hidden_states`, and `key`/`value` tensors
from `key_value_states`.
"""
# Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
mixed_kv, _ = self.linear_kv(key_value_states)
# [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
new_tensor_shape = mixed_kv.size()[:-1] + (
self.num_attention_heads_per_partition,
2 * self.hidden_size_per_attention_head,
)
mixed_kv = mixed_kv.view(*new_tensor_shape)
# [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
(key, value) = tensor_parallel.split_tensor_along_last_dim(mixed_kv, 2)
# Attention head [sq, b, h] --> [sq, b, hp]
query, _ = self.linear_q(hidden_states)
# [sq, b, hp] --> [sq, b, np, hn]
new_tensor_shape = query.size()[:-1] + (
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
)
query = query.view(*new_tensor_shape)
return query, key, value
from importlib.metadata import version
from typing import Callable
import torch
import transformer_engine as te
from pkg_resources import packaging
from megatron.core.parallel_state import get_tensor_model_parallel_group
from megatron.core.tensor_parallel import get_cuda_rng_tracker
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.transformer_config import TransformerConfig
def _get_extra_te_kwargs(config: TransformerConfig):
extra_transformer_engine_kwargs = {}
from importlib.metadata import version
from pkg_resources import packaging
te_version = packaging.version.Version(version("transformer-engine"))
if te_version >= packaging.version.Version("0.12.0"):
if config.use_cpu_initialization:
extra_transformer_engine_kwargs["device"] = 'cpu'
else:
extra_transformer_engine_kwargs["device"] = torch.cuda.current_device()
return extra_transformer_engine_kwargs
class TENorm:
"""
A conditional wrapper to initialize an instance of Transformer-Engine's
`LayerNorm` or `RMSNorm` based on input
"""
def __new__(
cls,
config: TransformerConfig,
hidden_size: int,
eps: float = 1e-5,
sequence_parallel: bool = False,
normalization="LayerNorm",
**kwargs
):
zero_centered_gamma = kwargs.get('zero_centered_gamma', False)
if normalization == "LayerNorm":
instance = te.pytorch.LayerNorm(
hidden_size=hidden_size,
eps=eps,
sequence_parallel=sequence_parallel,
zero_centered_gamma=zero_centered_gamma,
**_get_extra_te_kwargs(config),
)
elif normalization == "RMSNorm":
assert hasattr(
te.pytorch, "RMSNorm"
), "Transformer-Engine >= v0.11 required to use this feature"
instance = te.pytorch.RMSNorm(
hidden_size=hidden_size,
eps=eps,
sequence_parallel=sequence_parallel,
zero_centered_gamma=zero_centered_gamma,
**_get_extra_te_kwargs(config),
)
else:
raise Exception('Only LayerNorm and RMSNorm are currently supported')
return instance
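# Illustrative sketch (added for documentation, not part of the original
# module): a hypothetical instantiation of TENorm, assuming a populated
# TransformerConfig is available as `config`. Requesting "RMSNorm" requires
# Transformer-Engine >= 0.11, as asserted above.
def _example_te_norm(config):
    return TENorm(
        config=config,
        hidden_size=config.hidden_size,
        eps=1e-5,
        normalization=config.normalization,
    )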
class TELinear(te.pytorch.Linear):
"""
Wrapper for the Transformer-Engine's `Linear` layer.
Note that if Megatron's parallel_state has not been initialized
yet, the tp_group passed to TE will be None and must be set later
via set_tensor_parallel_group().
"""
def __init__(
self,
input_size: int,
output_size: int,
config: TransformerConfig,
parallel_mode: str,
init_method: Callable,
*,
bias: bool = True,
skip_bias_add: bool = False,
**kwargs
):
self.config = config
# TE returns a zero length Tensor when bias=False and
# return_bias=True, but we prefer None. So in that case we
# tell TE to not return the bias, and return None
# ourselves. This way our forward always returns two values
# and we don't have to deal with the zero length Tensor.
self.te_return_bias = skip_bias_add and bias
super().__init__(
in_features=input_size,
out_features=output_size,
sequence_parallel=self.config.sequence_parallel,
fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion,
tp_group=get_tensor_model_parallel_group(check_initialized=False),
tp_size=self.config.tensor_model_parallel_size,
get_rng_state_tracker=get_cuda_rng_tracker,
init_method=init_method,
params_dtype=self.config.params_dtype,
parallel_mode=parallel_mode,
bias=bias,
return_bias=self.te_return_bias,
**_get_extra_te_kwargs(config),
**kwargs,
)
def forward(self, x):
out = super().forward(x)
# TE only returns a tuple when return_bias is True, otherwise
# it returns a single Tensor, we always want to return two
# values regardless of the arguments.
if self.te_return_bias:
return out
return out, None
class TELayerNormColumnParallelLinear(te.pytorch.LayerNormLinear):
"""
Wrapper for the Transformer-Engine's `LayerNormLinear` layer that combines
layernorm and linear layers
"""
def __init__(
self,
input_size: int,
output_size: int,
config: TransformerConfig,
init_method: Callable,
bias: bool,
skip_bias_add: bool,
**kwargs
):
self.config = config
# TE returns a zero length Tensor when bias=False and
# return_bias=True, but we prefer None. So in that case we
# tell TE to not return the bias, and return None
# ourselves. This way our forward always returns two values
# and we don't have to deal with the zero length Tensor.
self.te_return_bias = skip_bias_add and bias
# Only Transformer-Engine version >= 0.11.0 supports `RMSNorm`
te_version = packaging.version.Version(version("transformer-engine"))
if te_version >= packaging.version.Version("0.11.0"):
kwargs["normalization"] = self.config.normalization
super().__init__(
in_features=input_size,
out_features=output_size,
bias=bias,
sequence_parallel=self.config.sequence_parallel,
fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion,
tp_group=get_tensor_model_parallel_group(check_initialized=False),
tp_size=self.config.tensor_model_parallel_size,
get_rng_state_tracker=get_cuda_rng_tracker,
init_method=init_method,
params_dtype=self.config.params_dtype,
parallel_mode="column",
return_bias=self.te_return_bias,
zero_centered_gamma=self.config.layernorm_zero_centered_gamma,
**_get_extra_te_kwargs(config),
**kwargs,
)
def forward(self, x):
out = super().forward(x)
# TE only returns a tuple when return_bias is True, otherwise
# it returns a single Tensor, we always want to return two
# values regardless of the arguments.
if self.te_return_bias:
return out
return out, None
class TEColumnParallelLinear(TELinear):
"""
Wrapper for the Transformer-Engine's `Linear` layer, specialized to behave
like megatron's `ColumnParallelLinear` layer.
"""
def __init__(self, input_size: int, output_size: int, config: TransformerConfig, **kwargs):
self.config = config
super().__init__(
input_size=input_size,
output_size=output_size,
config=self.config,
parallel_mode="column",
**kwargs,
)
class TERowParallelLinear(TELinear):
"""
Wrapper for the Transformer-Engine's `Linear` layer, specialized to behave
like megatron's `RowParallelLinear` layer.
"""
def __init__(self, input_size: int, output_size: int, config: TransformerConfig, **kwargs):
self.config = config
super().__init__(
input_size=input_size,
output_size=output_size,
config=self.config,
parallel_mode="row",
**kwargs,
)
class TEDotProductAttention(te.pytorch.DotProductAttention):
"""
Wrapper for the Transformer-Engine's `DotProductAttention` layer that also
has "flash attention" enabled.
Note that if Megatron's parallel_state has not been initialized
yet, the tp_group passed to TE will be None and must be set later
via set_tensor_parallel_group().
"""
def __init__(
self,
config: TransformerConfig,
layer_number: int = 1,
attn_mask_type: AttnMaskType = AttnMaskType.padding,
**kwargs
):
self.config = config
super().__init__(
num_attention_heads=self.config.num_attention_heads,
kv_channels=self.config.kv_channels,
attention_dropout=self.config.attention_dropout,
layer_number=layer_number,
attn_mask_type=attn_mask_type.name,
sequence_parallel=self.config.sequence_parallel,
tp_size=self.config.tensor_model_parallel_size,
get_rng_state_tracker=get_cuda_rng_tracker,
tp_group=get_tensor_model_parallel_group(check_initialized=False),
**kwargs,
)
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import math
import torch
from torch import Tensor
from megatron.core import parallel_state, tensor_parallel
from megatron.core.fusions.fused_softmax import FusedScaleMaskSoftmax
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.utils import attention_mask_func
from megatron.core.utils import divide
class DotProductAttention(MegatronModule):
"""
Region where selective activation recomputation is applied.
This region is memory intensive but less compute intensive, which
makes activation checkpointing more efficient for LLMs (20B+).
See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details.
We use the following notation:
h: hidden size
n: number of attention heads
p: number of tensor model parallel partitions
b: batch size
s: sequence length
"""
def __init__(
self, config: TransformerConfig, layer_number: int = 1, attn_mask_type=AttnMaskType.padding
):
super().__init__(config=config)
self.config: TransformerConfig = config
self.layer_number = max(1, layer_number)
self.attn_mask_type = attn_mask_type
projection_size = self.config.kv_channels * config.num_attention_heads
# Per attention head and per partition values.
world_size = parallel_state.get_tensor_model_parallel_world_size()
self.hidden_size_per_partition = divide(projection_size, world_size)
self.hidden_size_per_attention_head = divide(projection_size, config.num_attention_heads)
self.num_attention_heads_per_partition = divide(config.num_attention_heads, world_size)
coeff = None
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
if self.config.apply_query_key_layer_scaling:
coeff = self.layer_number
self.norm_factor *= coeff
self.scale_mask_softmax = FusedScaleMaskSoftmax(
input_in_fp16=self.config.fp16,
input_in_bf16=self.config.bf16,
attn_mask_type=self.attn_mask_type,
scaled_masked_softmax_fusion=self.config.masked_softmax_fusion,
mask_func=attention_mask_func,
softmax_in_fp32=self.config.attention_softmax_in_fp32,
scale=coeff,
)
# Dropout. Note that for a single iteration, this layer will generate
# different outputs on different number of parallel partitions but
# on average it should not be partition dependent.
self.attention_dropout = torch.nn.Dropout(self.config.attention_dropout)
def forward(
self, query_layer: Tensor, key_layer: Tensor, value_layer: Tensor, attention_mask: Tensor
):
# ===================================
# Raw attention scores. [b, n/p, s, s]
# ===================================
# [b, np, sq, sk]
output_size = (
query_layer.size(1),
query_layer.size(2),
query_layer.size(0),
key_layer.size(0),
)
# [sq, b, np, hn] -> [sq, b * np, hn]
# This will be a simple view when doing normal attention, but in group query attention
# the key and value tensors are repeated to match the queries so you can't use simple strides
# to extract the queries.
query_layer = query_layer.reshape(output_size[2], output_size[0] * output_size[1], -1)
# [sk, b, np, hn] -> [sk, b * np, hn]
key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
# preallocating input tensor: [b * np, sq, sk]
matmul_input_buffer = parallel_state.get_global_memory_buffer().get_tensor(
(output_size[0] * output_size[1], output_size[2], output_size[3]),
query_layer.dtype,
"mpu",
)
# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.baddbmm(
matmul_input_buffer,
query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
beta=0.0,
alpha=(1.0 / self.norm_factor),
)
# change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size)
# ===========================
# Attention probs and dropout
# ===========================
# attention scores and attention mask [b, np, sq, sk]
attention_probs: Tensor = self.scale_mask_softmax(attention_scores, attention_mask)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
if not self.config.sequence_parallel:
with tensor_parallel.get_cuda_rng_tracker().fork():
attention_probs = self.attention_dropout(attention_probs)
else:
attention_probs = self.attention_dropout(attention_probs)
# =========================
# Context layer. [sq, b, hp]
# =========================
# value_layer -> context layer.
# [sk, b, np, hn] --> [b, np, sq, hn]
# context layer shape: [b, np, sq, hn]
output_size = (
value_layer.size(1),
value_layer.size(2),
query_layer.size(0),
value_layer.size(3),
)
# change view [sk, b * np, hn]
value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
# change view [b * np, sq, sk]
attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
# matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
# change view [b, np, sq, hn]
context_layer = context_layer.view(*output_size)
# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
# [sq, b, np, hn] --> [sq, b, hp]
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
return context_layer
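# Illustrative sketch (added for documentation, not part of the original
# class): the shape bookkeeping of the forward pass above with hypothetical
# sizes (sq = sk = 1024, b = 2, np = 16, hn = 64, hp = np * hn = 1024).
def _example_attention_shapes(sq=1024, sk=1024, b=2, np_heads=16, head_dim=64):
    scores_shape = (b * np_heads, sq, sk)         # baddbmm result: [b*np, sq, sk] = (32, 1024, 1024)
    context_shape = (b * np_heads, sq, head_dim)  # bmm with values: [b*np, sq, hn] = (32, 1024, 64)
    output_shape = (sq, b, np_heads * head_dim)   # final view:      [sq, b, hp]   = (1024, 2, 1024)
    return scores_shape, context_shape, output_shape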
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import enum
# can we get rid of this?
# it's being used in pipeline schedules
class ModelType(enum.Enum):
encoder_or_decoder = 1
encoder_and_decoder = 2
# class LayerType(enum.Enum):
# encoder = 1
# decoder = 2
class AttnType(enum.Enum):
self_attn = 1
cross_attn = 2
class AttnMaskType(enum.Enum):
padding = 1
causal = 2
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import torch
class IdentityOp(torch.nn.Module):
"""
This is a placeholder for IdentityOp (NoOp)
"""
def __init__(self, *args, **kwargs):
super(IdentityOp, self).__init__()
def forward(self, x, *args, **kwargs):
return x
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import torch
import torch.nn.functional as F
from megatron.core import tensor_parallel
from megatron.core.fusions.fused_bias_gelu import bias_gelu_impl
from megatron.core.transformer.custom_layers.transformer_engine import (
TELayerNormColumnParallelLinear,
TERowParallelLinear,
)
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig
class MLP(MegatronModule):
"""
MLP takes an input with hidden size h, projects it to a 4*h
hidden dimension, applies a nonlinear transformation, and projects the
state back to hidden size h.
Returns an output and a bias to be added to the output.
If config.add_bias_linear is False, the bias returned is None.
We use the following notation:
h: hidden size
p: number of tensor model parallel partitions
b: batch size
s: sequence length
"""
def __init__(self, config: TransformerConfig):
super().__init__(config=config)
self.config: TransformerConfig = config
# If this is a gated linear unit we double the output width, see https://arxiv.org/pdf/2002.05202.pdf
ffn_hidden_size = self.config.ffn_hidden_size
if self.config.gated_linear_unit:
ffn_hidden_size *= 2
self.linear_fc1 = TELayerNormColumnParallelLinear(
self.config.hidden_size,
ffn_hidden_size,
config=self.config,
init_method=self.config.init_method,
bias=self.config.add_bias_linear,
skip_bias_add=True,
)
if self.config.gated_linear_unit:
def glu(x):
x = torch.chunk(x, 2, dim=-1)
return self.config.activation_func(x[0]) * x[1]
self.activation_func = glu
else:
self.activation_func = self.config.activation_func
self.linear_fc2 = TERowParallelLinear(
self.config.ffn_hidden_size,
self.config.hidden_size,
config=self.config,
init_method=self.config.output_layer_init_method,
bias=self.config.add_bias_linear,
skip_bias_add=True,
)
def forward(self, hidden_states):
# [s, b, 4 * h/p]
intermediate_parallel, bias_parallel = self.linear_fc1(hidden_states)
if self.config.bias_gelu_fusion:
assert self.config.add_bias_linear is True
assert self.activation_func == F.gelu
intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
else:
if bias_parallel is not None:
intermediate_parallel = intermediate_parallel + bias_parallel
intermediate_parallel = self.activation_func(intermediate_parallel)
# [s, b, h]
output, output_bias = self.linear_fc2(intermediate_parallel)
return output, output_bias
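# Illustrative sketch (added for documentation, not part of the original
# class): the projection widths implied by the docstring above, with
# hypothetical sizes. With a gated linear unit the first projection is
# doubled so that the GLU chunk can split it back into an activation half
# and a gate half.
def _example_mlp_widths(hidden_size=4096, ffn_hidden_size=16384, gated_linear_unit=True):
    fc1_output = ffn_hidden_size * 2 if gated_linear_unit else ffn_hidden_size  # 32768 or 16384
    fc2_input, fc2_output = ffn_hidden_size, hidden_size                        # 16384 -> 4096
    return fc1_output, (fc2_input, fc2_output)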