Commit 42c40715 authored by Jared Casper

Merge branch 'core' into 'main'

Move most of mpu functionality into a new "Megatron core"

See merge request ADLR/megatron-lm!462
parents fdc801ed d3a416cd
[html]
directory = coverage
[run]
data_file = .coverage_$LOCAL_RANK
__pycache__
*.so
build
.coverage_*
*.egg-info
 image: gitlab-master.nvidia.com/dl/dgx/pytorch:21.12-py3-devel
 test:
+  tags:
+    - docker_gpu_enabled
   script:
-    - pytest --junitxml=report.xml tests
+    - torchrun --nproc_per_node=8 -m pytest --cov-report=term --cov-report=html --cov=megatron/core tests/
+  coverage: '/(?i)total.*? (100(?:\.0+)?\%|[1-9]?\d(?:\.\d+)?\%)$/'
   artifacts:
-    when: always
-    reports:
-      junit: report.xml
+    paths:
+      - coverage
+    expire_in: 30 days
\ No newline at end of file
@@ -10,7 +10,6 @@ from .global_vars import get_tokenizer
 from .global_vars import get_tensorboard_writer
 from .global_vars import get_adlr_autoresume
 from .global_vars import get_timers
-from .global_vars import get_global_memory_buffer
 from .initialize import initialize_megatron
 from .utils import (print_rank_0,
...
@@ -168,14 +168,6 @@ def validate_args(args, defaults={}):
     if args.accumulate_allreduce_grads_in_fp32:
         assert args.DDP_impl == 'local'
         assert args.use_contiguous_buffers_in_local_ddp
-    else:
-        if args.gradient_accumulation_fusion:
-            args.gradient_accumulation_fusion = False
-            if args.rank == 0:
-                print('Gradient accumulation fusion to linear layer weight '
-                      'gradient computation is supported only with fp32 '
-                      'gradient accumulation. Setting gradient_accumulation_fusion '
-                      'to False', flush=True)

     # If we use the distributed optimizer, we need to have local DDP
     # and we should make sure use-contiguous-buffers-in-local-ddp is on.
@@ -321,6 +313,18 @@ def validate_args(args, defaults={}):
     if args.sequence_parallel:
         args.async_tensor_model_parallel_allreduce = False

+    if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
+        if args.sequence_parallel:
+            raise RuntimeError(
+                "Using sequence parallelism requires setting the environment variable "
+                "CUDA_DEVICE_MAX_CONNECTIONS to 1")
+        if args.async_tensor_model_parallel_allreduce:
+            raise RuntimeError(
+                "Using async gradient all reduce requires setting the environment "
+                "variable CUDA_DEVICE_MAX_CONNECTIONS to 1")
+
     _print_args(args)
     return args
...
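The new validation only inspects the environment; the variable itself must be exported before the training processes start. A minimal launcher-side sketch (the script name and flags are illustrative, not part of this merge request):

import os
import subprocess

# CUDA_DEVICE_MAX_CONNECTIONS has to be in the environment before CUDA is
# initialized in the workers, otherwise validate_args() above will raise.
env = dict(os.environ, CUDA_DEVICE_MAX_CONNECTIONS="1")
subprocess.run(
    ["torchrun", "--nproc_per_node=8", "pretrain_gpt.py", "--sequence-parallel"],
    env=env, check=True)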
@@ -9,8 +9,8 @@ import numpy as np

 import torch

-from megatron import (mpu,
-                      update_num_microbatches)
+from megatron import update_num_microbatches
+from megatron.core import mpu, tensor_parallel
 from .global_vars import get_args
 from .utils import (unwrap_model,
                     print_rank_0)
@@ -185,7 +185,7 @@ def get_rng_state():
                  'np_rng_state': np.random.get_state(),
                  'torch_rng_state': torch.get_rng_state(),
                  'cuda_rng_state': torch.cuda.get_rng_state(),
-                 'rng_tracker_states': mpu.get_cuda_rng_tracker().get_states()}
+                 'rng_tracker_states': tensor_parallel.get_cuda_rng_tracker().get_states()}

     rng_state_list = None
     if torch.distributed.is_initialized() and \
@@ -590,7 +590,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
                 # Check for empty states array
                 if not rng_state['rng_tracker_states']:
                     raise KeyError
-                mpu.get_cuda_rng_tracker().set_states(
+                tensor_parallel.get_cuda_rng_tracker().set_states(
                     rng_state['rng_tracker_states'])
             else:  # backward compatability
                 random.setstate(model_state_dict['random_rng_state'])
@@ -600,7 +600,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
             # Check for empty states array
             if not model_state_dict['rng_tracker_states']:
                 raise KeyError
-            mpu.get_cuda_rng_tracker().set_states(
+            tensor_parallel.get_cuda_rng_tracker().set_states(
                 model_state_dict['rng_tracker_states'])
     except KeyError:
         print_rank_0('Unable to load rng state from checkpoint {}. '
...
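For reference, a short sketch (not taken from this merge request) of how checkpointing code now round-trips the CUDA RNG tracker state through megatron.core.tensor_parallel instead of the old mpu import:

from megatron.core import tensor_parallel

def capture_rng_tracker_states():
    # Dict of CUDA RNG states, keyed by tracker name.
    return tensor_parallel.get_cuda_rng_tracker().get_states()

def restore_rng_tracker_states(states):
    tensor_parallel.get_cuda_rng_tracker().set_states(states)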
import megatron.core.parallel_state
import megatron.core.tensor_parallel
import megatron.core.utils
# Alias parallel_state as mpu, its legacy name
mpu = parallel_state
__all__ = [
"parallel_state",
"tensor_parallel",
"utils",
]
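A usage sketch of the new package layout (illustrative only, and it assumes model parallelism has already been initialized): existing call sites can keep using the mpu name, since it is just an alias of parallel_state.

from megatron.core import mpu, parallel_state, tensor_parallel

# mpu is only an alias, so legacy-style calls keep working unchanged.
assert mpu is parallel_state
tp_rank = mpu.get_tensor_model_parallel_rank()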
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

 """Model and data parallel groups."""

 import torch
+from typing import Optional

-from .utils import ensure_divisibility
+from .utils import GlobalMemoryBuffer

 # Intra-layer model parallel group that the current rank belongs to.
 _TENSOR_MODEL_PARALLEL_GROUP = None
@@ -45,17 +44,16 @@ _PIPELINE_GLOBAL_RANKS = None
 # rank when broadcasting weights from src to all other data parallel ranks
 _DATA_PARALLEL_GLOBAL_RANKS = None

+# Memory buffers to avoid dynamic memory allocation
+_GLOBAL_MEMORY_BUFFER = None

-def is_unitialized():
-    """Useful for code segments that may be accessed with or without mpu initialization"""
-    return _DATA_PARALLEL_GROUP is None
-
-def initialize_model_parallel(tensor_model_parallel_size_=1,
-                              pipeline_model_parallel_size_=1,
-                              virtual_pipeline_model_parallel_size_=None,
-                              pipeline_model_parallel_split_rank_=None):
+def initialize_model_parallel(
+    tensor_model_parallel_size: int = 1,
+    pipeline_model_parallel_size: int = 1,
+    virtual_pipeline_model_parallel_size: Optional[int] = None,
+    pipeline_model_parallel_split_rank: Optional[int] = None,
+) -> None:
     """
     Initialize model data parallel groups.
@@ -67,7 +65,6 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
     pipeline_model_parallel_split_rank: for models with both encoder and decoder,
         rank in pipeline with split point.

     Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
     use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
     the model pipeline. The present function will
@@ -84,49 +81,48 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
     with a total of 16 GPUs, rank 0 to 7 belong to the first box and
     ranks 8 to 15 belong to the second box.
     """
-    if torch.distributed.get_rank() == 0:
-        print('> initializing tensor model parallel with size {}'.format(
-            tensor_model_parallel_size_))
-        print('> initializing pipeline model parallel with size {}'.format(
-            pipeline_model_parallel_size_))
     # Get world size and rank. Ensure some consistencies.
     assert torch.distributed.is_initialized()
-    world_size = torch.distributed.get_world_size()
-    tensor_model_parallel_size = min(tensor_model_parallel_size_, world_size)
-    pipeline_model_parallel_size = min(pipeline_model_parallel_size_, world_size)
-    ensure_divisibility(world_size,
-                        tensor_model_parallel_size * pipeline_model_parallel_size)
-    data_parallel_size = world_size // (tensor_model_parallel_size *
-                                        pipeline_model_parallel_size)
-    num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
-    num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
-    num_data_parallel_groups = world_size // data_parallel_size
+    world_size: int = torch.distributed.get_world_size()
+
+    if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0:
+        raise RuntimeError(
+            f"world_size ({world_size}) is not divisible by tensor_model_parallel_size "
+            f"({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size})"
+        )
+
+    data_parallel_size: int = world_size // (tensor_model_parallel_size *
+                                             pipeline_model_parallel_size)
+
+    num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size
+    num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
+    num_data_parallel_groups: int = world_size // data_parallel_size

-    if virtual_pipeline_model_parallel_size_ is not None:
+    if virtual_pipeline_model_parallel_size is not None:
+        if not pipeline_model_parallel_size > 2:
+            raise RuntimeError("pipeline-model-parallel size should be greater than 2 with "
+                               "interleaved schedule")
         global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
         global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
         _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
-        _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size_
+        _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size

-    if pipeline_model_parallel_split_rank_ is not None:
+    if pipeline_model_parallel_split_rank is not None:
         global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
-        _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank_
+        _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank
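A usage sketch of the renamed keyword arguments (assuming 16 ranks already launched, e.g. with torchrun, so the divisibility check above passes):

import torch
from megatron.core import parallel_state

torch.distributed.init_process_group(backend="nccl")
parallel_state.initialize_model_parallel(
    tensor_model_parallel_size=2,
    pipeline_model_parallel_size=4,
)
# With 16 ranks this yields data_parallel_size = 16 // (2 * 4) = 2.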
     rank = torch.distributed.get_rank()

     # Build the data-parallel groups.
     global _DATA_PARALLEL_GROUP
     global _DATA_PARALLEL_GLOBAL_RANKS
-    assert _DATA_PARALLEL_GROUP is None, \
-        'data parallel group is already initialized'
+    assert _DATA_PARALLEL_GROUP is None, 'data parallel group is already initialized'
     all_data_parallel_group_ranks = []
     for i in range(pipeline_model_parallel_size):
         start_rank = i * num_pipeline_model_parallel_groups
         end_rank = (i + 1) * num_pipeline_model_parallel_groups
         for j in range(tensor_model_parallel_size):
-            ranks = range(start_rank + j, end_rank,
-                          tensor_model_parallel_size)
+            ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)
             all_data_parallel_group_ranks.append(list(ranks))
             group = torch.distributed.new_group(ranks)
             if rank in ranks:
@@ -135,8 +131,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
     # Build the model-parallel groups.
     global _MODEL_PARALLEL_GROUP
-    assert _MODEL_PARALLEL_GROUP is None, \
-        'model parallel group is already initialized'
+    assert _MODEL_PARALLEL_GROUP is None, 'model parallel group is already initialized'
     for i in range(data_parallel_size):
         ranks = [data_parallel_group_ranks[i]
                  for data_parallel_group_ranks in all_data_parallel_group_ranks]
@@ -163,15 +158,13 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
         'pipeline model parallel group is already initialized'
     global _EMBEDDING_GROUP
     global _EMBEDDING_GLOBAL_RANKS
-    assert _EMBEDDING_GROUP is None, \
-        'embedding group is already initialized'
+    assert _EMBEDDING_GROUP is None, 'embedding group is already initialized'
     global _POSITION_EMBEDDING_GROUP
     global _POSITION_EMBEDDING_GLOBAL_RANKS
     assert _POSITION_EMBEDDING_GROUP is None, \
         'position embedding group is already initialized'
     for i in range(num_pipeline_model_parallel_groups):
-        ranks = range(i, world_size,
-                      num_pipeline_model_parallel_groups)
+        ranks = range(i, world_size, num_pipeline_model_parallel_groups)
         group = torch.distributed.new_group(ranks)
         if rank in ranks:
             _PIPELINE_MODEL_PARALLEL_GROUP = group
@@ -181,14 +174,14 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
         if len(ranks) > 1:
             embedding_ranks = [ranks[0], ranks[-1]]
             position_embedding_ranks = [ranks[0]]
-            if pipeline_model_parallel_split_rank_ is not None:
-                if ranks[pipeline_model_parallel_split_rank_] not in embedding_ranks:
+            if pipeline_model_parallel_split_rank is not None:
+                if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks:
                     embedding_ranks = [ranks[0],
-                                       ranks[pipeline_model_parallel_split_rank_],
+                                       ranks[pipeline_model_parallel_split_rank],
                                        ranks[-1]]
-                if ranks[pipeline_model_parallel_split_rank_] not in position_embedding_ranks:
+                if ranks[pipeline_model_parallel_split_rank] not in position_embedding_ranks:
                     position_embedding_ranks = [ranks[0],
-                                                ranks[pipeline_model_parallel_split_rank_]]
+                                                ranks[pipeline_model_parallel_split_rank]]
         else:
             embedding_ranks = ranks
             position_embedding_ranks = ranks
@@ -205,6 +198,12 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
         if rank in ranks:
             _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks

+    # Initialize global memory buffer
+    # This isn't really "parallel state" but there isn't another good place to
+    # put this. If we end up with a more generic initialization of megatron-core
+    # we could stick it there
+    _set_global_memory_buffer()
def model_parallel_is_initialized(): def model_parallel_is_initialized():
"""Check if model and data parallel groups are initialized.""" """Check if model and data parallel groups are initialized."""
...@@ -297,6 +296,12 @@ def set_pipeline_model_parallel_rank(rank): ...@@ -297,6 +296,12 @@ def set_pipeline_model_parallel_rank(rank):
_MPU_PIPELINE_MODEL_PARALLEL_RANK = rank _MPU_PIPELINE_MODEL_PARALLEL_RANK = rank
def set_pipeline_model_parallel_split_rank(rank):
"""Set pipeline model parallel split rank."""
global _MPU_PIPELINE_MODEL_PARALLEL_SPLIT_RANK
_MPU_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = rank
def get_tensor_model_parallel_rank(): def get_tensor_model_parallel_rank():
"""Return my rank for the tensor model parallel group.""" """Return my rank for the tensor model parallel group."""
global _MPU_TENSOR_MODEL_PARALLEL_RANK global _MPU_TENSOR_MODEL_PARALLEL_RANK
...@@ -313,57 +318,6 @@ def get_pipeline_model_parallel_rank(): ...@@ -313,57 +318,6 @@ def get_pipeline_model_parallel_rank():
return torch.distributed.get_rank(group=get_pipeline_model_parallel_group()) return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
def get_num_layers(args, is_encoder_and_decoder_model, is_decoder=False):
"""Compute the number of transformer layers resident on the current rank."""
if get_pipeline_model_parallel_world_size() > 1:
if is_encoder_and_decoder_model:
assert args.pipeline_model_parallel_split_rank is not None
# When a standalone embedding stage is used, a rank is taken from
# the encoder's ranks, to be used for the encoder's embedding
# layer. This way, the rank referenced by the 'split rank' remains
# the same whether or not a standalone embedding stage is used.
num_ranks_in_encoder = (
args.pipeline_model_parallel_split_rank - 1
if args.standalone_embedding_stage else
args.pipeline_model_parallel_split_rank
)
num_ranks_in_decoder = args.transformer_pipeline_model_parallel_size - num_ranks_in_encoder
assert args.encoder_num_layers % num_ranks_in_encoder == 0, \
'encoder_num_layers (%d) must be divisible by number of ranks given to encoder (%d)' % (args.encoder_num_layers, num_ranks_in_encoder)
assert args.decoder_num_layers % num_ranks_in_decoder == 0, \
'decoder_num_layers (%d) must be divisible by number of ranks given to decoder (%d)' % (args.decoder_num_layers, num_ranks_in_decoder)
if is_pipeline_stage_before_split():
num_layers = (
0
if args.standalone_embedding_stage
and get_pipeline_model_parallel_rank() == 0 else
args.encoder_num_layers // num_ranks_in_encoder
)
else:
num_layers = args.decoder_num_layers // num_ranks_in_decoder
else:
assert args.num_layers == args.encoder_num_layers
assert args.num_layers % args.transformer_pipeline_model_parallel_size == 0, \
'num_layers must be divisible by transformer_pipeline_model_parallel_size'
# When a standalone embedding stage is used, all transformer layers
# are divided among pipeline rank >= 1, while on pipeline rank 0,
# ranks either contain the input embedding layer (virtual pp rank 0),
# or no layers at all (virtual pp rank >= 1).
num_layers = (
0
if args.standalone_embedding_stage
and get_pipeline_model_parallel_rank() == 0 else
args.num_layers // args.transformer_pipeline_model_parallel_size
)
else:
if not is_decoder:
num_layers = args.encoder_num_layers
else:
num_layers = args.decoder_num_layers
return num_layers
def is_pipeline_first_stage(ignore_virtual=False): def is_pipeline_first_stage(ignore_virtual=False):
"""Return True if in the first pipeline model-parallel stage, False otherwise.""" """Return True if in the first pipeline model-parallel stage, False otherwise."""
...@@ -484,18 +438,23 @@ def get_data_parallel_src_rank(): ...@@ -484,18 +438,23 @@ def get_data_parallel_src_rank():
def get_pipeline_model_parallel_first_rank(): def get_pipeline_model_parallel_first_rank():
"""Return the global rank of the first process in the pipeline for the
current tensor parallel group"""
assert _PIPELINE_GLOBAL_RANKS is not None, \ assert _PIPELINE_GLOBAL_RANKS is not None, \
"Pipeline parallel group is not initialized" "Pipeline parallel group is not initialized"
return _PIPELINE_GLOBAL_RANKS[0] return _PIPELINE_GLOBAL_RANKS[0]
def get_pipeline_model_parallel_last_rank(): def get_pipeline_model_parallel_last_rank():
"""Return the global rank of the last process in the pipeline for the
current tensor parallel group"""
assert _PIPELINE_GLOBAL_RANKS is not None, \ assert _PIPELINE_GLOBAL_RANKS is not None, \
"Pipeline parallel group is not initialized" "Pipeline parallel group is not initialized"
last_rank_local = get_pipeline_model_parallel_world_size() - 1 last_rank_local = get_pipeline_model_parallel_world_size() - 1
return _PIPELINE_GLOBAL_RANKS[last_rank_local] return _PIPELINE_GLOBAL_RANKS[last_rank_local]
def get_pipeline_model_parallel_next_rank(): def get_pipeline_model_parallel_next_rank():
"""Return the global rank that follows the caller in the pipeline"""
assert _PIPELINE_GLOBAL_RANKS is not None, \ assert _PIPELINE_GLOBAL_RANKS is not None, \
"Pipeline parallel group is not initialized" "Pipeline parallel group is not initialized"
rank_in_pipeline = get_pipeline_model_parallel_rank() rank_in_pipeline = get_pipeline_model_parallel_rank()
...@@ -504,6 +463,7 @@ def get_pipeline_model_parallel_next_rank(): ...@@ -504,6 +463,7 @@ def get_pipeline_model_parallel_next_rank():
def get_pipeline_model_parallel_prev_rank(): def get_pipeline_model_parallel_prev_rank():
"""Return the global rank that preceeds the caller in the pipeline"""
assert _PIPELINE_GLOBAL_RANKS is not None, \ assert _PIPELINE_GLOBAL_RANKS is not None, \
"Pipeline parallel group is not initialized" "Pipeline parallel group is not initialized"
rank_in_pipeline = get_pipeline_model_parallel_rank() rank_in_pipeline = get_pipeline_model_parallel_rank()
...@@ -520,6 +480,17 @@ def get_data_parallel_rank(): ...@@ -520,6 +480,17 @@ def get_data_parallel_rank():
"""Return my rank for the data parallel group.""" """Return my rank for the data parallel group."""
return torch.distributed.get_rank(group=get_data_parallel_group()) return torch.distributed.get_rank(group=get_data_parallel_group())
def _set_global_memory_buffer():
"""Initialize global buffer"""
global _GLOBAL_MEMORY_BUFFER
assert _GLOBAL_MEMORY_BUFFER is None, 'global memory buffer is already initialized'
_GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer()
def get_global_memory_buffer():
"""Return the global GlobalMemoryBuffer object"""
assert _GLOBAL_MEMORY_BUFFER is not None, 'global memory buffer is not initialized'
return _GLOBAL_MEMORY_BUFFER
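The buffer is meant to be fetched through the accessor above. A sketch of the intended pattern, assuming GlobalMemoryBuffer (from megatron/core/utils.py) keeps one persistent allocation per name and hands it out through a get_tensor() helper; the shape and name below are illustrative:

import torch
from megatron.core import parallel_state

buffer = parallel_state.get_global_memory_buffer()
# Reuses the same underlying storage on every call instead of reallocating.
gathered = buffer.get_tensor([1024, 8, 4096], torch.float16, "mpu")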
def destroy_model_parallel(): def destroy_model_parallel():
"""Set the groups to none.""" """Set the groups to none."""
...@@ -535,3 +506,17 @@ def destroy_model_parallel(): ...@@ -535,3 +506,17 @@ def destroy_model_parallel():
_EMBEDDING_GROUP = None _EMBEDDING_GROUP = None
global _POSITION_EMBEDDING_GROUP global _POSITION_EMBEDDING_GROUP
_POSITION_EMBEDDING_GROUP = None _POSITION_EMBEDDING_GROUP = None
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
global _MPU_TENSOR_MODEL_PARALLEL_RANK
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
global _MPU_PIPELINE_MODEL_PARALLEL_RANK
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None
global _GLOBAL_MEMORY_BUFFER
_GLOBAL_MEMORY_BUFFER = None
from .cross_entropy import vocab_parallel_cross_entropy
from .data import broadcast_data
from .layers import (
ColumnParallelLinear,
RowParallelLinear,
VocabParallelEmbedding,
set_tensor_model_parallel_attributes,
set_defaults_if_not_set_tensor_model_parallel_attributes,
copy_tensor_model_parallel_attributes,
param_is_not_tensor_parallel_duplicate,
linear_with_grad_accumulation_and_async_allreduce
)
from .mappings import (
copy_to_tensor_model_parallel_region,
gather_from_tensor_model_parallel_region,
gather_from_sequence_parallel_region,
scatter_to_tensor_model_parallel_region,
scatter_to_sequence_parallel_region,
)
from .random import (
checkpoint,
get_cuda_rng_tracker,
model_parallel_cuda_manual_seed,
)
from .utils import (
split_tensor_along_last_dim,
split_tensor_into_1d_equal_chunks,
gather_split_1d_tensor,
)
__all__ = [
# cross_entropy.py
"vocab_parallel_cross_entropy",
# data.py
"broadcast_data",
#layers.py
"ColumnParallelLinear",
"RowParallelLinear",
"VocabParallelEmbedding",
"set_tensor_model_parallel_attributes",
"set_defaults_if_not_set_tensor_model_parallel_attributes",
"copy_tensor_model_parallel_attributes",
"param_is_not_tensor_parallel_duplicate",
"linear_with_grad_accumulation_and_async_allreduce",
# mappings.py
"copy_to_tensor_model_parallel_region",
"gather_from_tensor_model_parallel_region",
"gather_from_sequence_parallel_region",
# "reduce_from_tensor_model_parallel_region",
"scatter_to_tensor_model_parallel_region",
"scatter_to_sequence_parallel_region",
# random.py
"checkpoint",
"get_cuda_rng_tracker",
"model_parallel_cuda_manual_seed",
# utils.py
"split_tensor_along_last_dim",
"split_tensor_into_1d_equal_chunks",
"gather_split_1d_tensor",
]
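As an illustration (not part of the diff), the re-exported RNG helpers are typically used together: model_parallel_cuda_manual_seed() gives each tensor-parallel rank its own dropout seed, and get_cuda_rng_tracker().fork() switches to that seed for regions whose randomness should differ across ranks.

from megatron.core import tensor_parallel

tensor_parallel.model_parallel_cuda_manual_seed(1234)
with tensor_parallel.get_cuda_rng_tracker().fork():
    # randomness in here (e.g. dropout masks) differs per tensor-parallel rank
    ...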
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

 import torch

-from .initialize import get_tensor_model_parallel_group
-from .initialize import get_tensor_model_parallel_rank
-from .initialize import get_tensor_model_parallel_world_size
+from megatron.core.parallel_state import (
+    get_tensor_model_parallel_group,
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size
+)
 from .utils import VocabUtility


 class _VocabParallelCrossEntropy(torch.autograd.Function):

     @staticmethod
-    def forward(ctx, vocab_parallel_logits, target):
+    def forward(ctx, vocab_parallel_logits, target, label_smoothing=0.0):

         # Maximum value along vocab dimension across all GPUs.
         logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
@@ -20,7 +22,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
                                      op=torch.distributed.ReduceOp.MAX,
                                      group=get_tensor_model_parallel_group())
         # Subtract the maximum value.
-        vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
+        vocab_parallel_logits = vocab_parallel_logits - logits_max.unsqueeze(dim=-1)

         # Get the partition's vocab indecies
         get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
@@ -62,8 +64,32 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
         # Loss = log(sum(exp(logits))) - predicted-logit.
         loss = torch.log(sum_exp_logits) - predicted_logits

-        # Store softmax, target-mask and masked-target for backward pass.
+        # Normalize and optionally smooth logits
         exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
-        ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
+
+        vocab_size = exp_logits.size(-1)
+        if label_smoothing > 0:
+            """
+            We'd like to assign 1 / (K - 1) probability mass to every index that is not the ground truth.
+            = (1 - alpha) * y_gt + alpha * mean(y_{i for i != gt})
+            = (1 - alpha) * y_gt + (alpha / (K - 1)) * \sum_{i != gt} y_i
+            = ((K - 1) * (1 - alpha) / (K - 1)) * y_gt + (alpha / (K - 1)) * \sum_{i != gt} y_i
+            = (K * (1 - alpha) - 1) / (K - 1)) * y_gt + (alpha / (K - 1)) * \sum_{i} y_i
+            = (1 - (alpha * K) / (K - 1)) * y_gt + ( (alpha * K) / (K - 1) ) * \sum_{i} y_i / K
+            From: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/losses/smoothed_cross_entropy.py
+            """
+            assert 1.0 > label_smoothing > 0.0
+            smoothing = label_smoothing * vocab_size / (vocab_size - 1)
+
+            # Exp logits at this point are normalized probabilities. So we can just take the log to get log-probs.
+            log_probs = torch.log(exp_logits)
+            mean_log_probs = log_probs.mean(dim=-1)
+            loss = (1.0 - smoothing) * loss - smoothing * mean_log_probs
+
+        ctx.label_smoothing, ctx.vocab_size = label_smoothing, vocab_size
+
+        # Store softmax, target-mask and masked-target for backward pass.
+        ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)

         return loss
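The smoothing algebra quoted in the docstring can be checked against a single-rank reference; the sketch below is for intuition only and is not the parallel implementation added here:

import torch
import torch.nn.functional as F

def smoothed_cross_entropy(logits, target, label_smoothing=0.1):
    K = logits.size(-1)
    log_probs = F.log_softmax(logits, dim=-1)
    # Standard cross entropy for the ground-truth class.
    nll = -log_probs.gather(-1, target.unsqueeze(-1)).squeeze(-1)
    smoothing = label_smoothing * K / (K - 1)
    # Same form as above: (1 - smoothing) * nll - smoothing * mean log-prob.
    return (1.0 - smoothing) * nll - smoothing * log_probs.mean(dim=-1)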
@@ -73,6 +99,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
         # Retreive tensors from the forward path.
         softmax, target_mask, masked_target_1d = ctx.saved_tensors
+        label_smoothing, vocab_size = ctx.label_smoothing, ctx.vocab_size

         # All the inputs have softmax as thier gradient.
         grad_input = softmax
@@ -83,15 +110,34 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
         # Add the gradient from matching classes.
         arange_1d = torch.arange(start=0, end=grad_2d.size()[0],
                                  device=grad_2d.device)
-        grad_2d[arange_1d, masked_target_1d] -= (
-            1.0 - target_mask.view(-1).float())
+
+        softmax_update = 1.0 - target_mask.view(-1).float()
+
+        if label_smoothing > 0:
+            smoothing = label_smoothing * vocab_size / (vocab_size - 1)
+            grad_2d[arange_1d, masked_target_1d] -= (1.0 - smoothing) * softmax_update
+            average_grad = 1 / vocab_size
+            grad_2d[arange_1d, :] -= smoothing * average_grad
+        else:
+            grad_2d[arange_1d, masked_target_1d] -= softmax_update

         # Finally elementwise multiplication with the output gradients.
         grad_input.mul_(grad_output.unsqueeze(dim=-1))

-        return grad_input, None
+        return grad_input, None, None
-def vocab_parallel_cross_entropy(vocab_parallel_logits, target):
-    """Helper function for the cross entropy."""
-    return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target)
+def vocab_parallel_cross_entropy(vocab_parallel_logits, target, label_smoothing=0.0):
+    """
+    Performs cross entropy loss when logits are split across tensor parallel ranks
+
+    Arguments:
+        vocab_parallel_logits: logits split across tensor parallel ranks
+                               dimension is [sequence_length, batch_size, hidden_size]
+
+        target: correct vocab ids of dimension [sequence_length, micro_batch_size]
+
+        label_smoothing: smoothing factor, must be in range [0.0, 1.0)
+                         default is no smoothing (=0.0)
+    """
+    return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target, label_smoothing)
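A call sketch; the logits tensor is assumed to already be split along the vocab dimension across the tensor-parallel group, and both tensors below are placeholders:

from megatron.core import tensor_parallel

# vocab_parallel_logits: [sequence_length, micro_batch_size, vocab_size // tp_world_size]
# target:                [sequence_length, micro_batch_size]
losses = tensor_parallel.vocab_parallel_cross_entropy(
    vocab_parallel_logits, target, label_smoothing=0.1)
loss = losses.float().mean()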
@@ -2,9 +2,11 @@

 import torch

-from .initialize import get_tensor_model_parallel_group
-from .initialize import get_tensor_model_parallel_rank
-from .initialize import get_tensor_model_parallel_src_rank
+from megatron.core.parallel_state import (
+    get_tensor_model_parallel_group,
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_src_rank,
+)


 _MAX_DATA_DIM = 5
...
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

 # Parts of the code here are adapted from PyTorch
 # repo: https://github.com/pytorch/pytorch

 import math
+import os
+from typing import Optional
+import warnings

 import torch
 import torch.nn.functional as F
 import torch.nn.init as init
 from torch.nn.parameter import Parameter

-from .initialize import get_tensor_model_parallel_rank
-from .initialize import get_tensor_model_parallel_world_size
-from .initialize import get_tensor_model_parallel_group
-from .mappings import copy_to_tensor_model_parallel_region
-from .mappings import gather_from_tensor_model_parallel_region
-from .mappings import gather_from_sequence_parallel_region
-from .mappings import reduce_from_tensor_model_parallel_region
-from .mappings import scatter_to_tensor_model_parallel_region
-from .mappings import reduce_scatter_to_sequence_parallel_region
+from megatron.core.parallel_state import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    get_tensor_model_parallel_group,
+    get_global_memory_buffer,
+)
+from .mappings import (
+    copy_to_tensor_model_parallel_region,
+    gather_from_tensor_model_parallel_region,
+    gather_from_sequence_parallel_region,
+    reduce_from_tensor_model_parallel_region,
+    scatter_to_tensor_model_parallel_region,
+    reduce_scatter_to_sequence_parallel_region,
+)
 from .random import get_cuda_rng_tracker
-from .utils import divide
-from .utils import split_tensor_along_last_dim
-from .utils import VocabUtility
-from megatron import get_args, get_global_memory_buffer
+from .utils import (
+    divide,
+    split_tensor_along_last_dim,
+    VocabUtility,
+)
+
+_grad_accum_fusion_available = True
+try:
+    import fused_weight_gradient_mlp_cuda
+except ImportError:
+    _grad_accum_fusion_available = False
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False, _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False,
'partition_dim': -1, 'partition_dim': -1,
...@@ -81,7 +94,8 @@ def _initialize_affine_weight_gpu(weight, init_method, ...@@ -81,7 +94,8 @@ def _initialize_affine_weight_gpu(weight, init_method,
def _initialize_affine_weight_cpu(weight, output_size, input_size, def _initialize_affine_weight_cpu(weight, output_size, input_size,
per_partition_size, partition_dim, per_partition_size, partition_dim,
init_method, stride=1, init_method, stride=1,
return_master_weight=False): return_master_weight=False,
*, params_dtype=torch.float32):
"""Initialize affine weight for model parallel. """Initialize affine weight for model parallel.
Build the master weight on all processes and scatter Build the master weight on all processes and scatter
...@@ -97,8 +111,7 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size, ...@@ -97,8 +111,7 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size,
dtype=torch.float, dtype=torch.float,
requires_grad=False) requires_grad=False)
init_method(master_weight) init_method(master_weight)
args = get_args() master_weight = master_weight.to(dtype=params_dtype)
master_weight = master_weight.to(dtype=args.params_dtype)
# Split and copy # Split and copy
per_partition_per_stride_size = divide(per_partition_size, stride) per_partition_per_stride_size = divide(per_partition_size, stride)
...@@ -123,11 +136,19 @@ class VocabParallelEmbedding(torch.nn.Module): ...@@ -123,11 +136,19 @@ class VocabParallelEmbedding(torch.nn.Module):
Arguments: Arguments:
num_embeddings: vocabulary size. num_embeddings: vocabulary size.
embedding_dim: size of hidden state. embedding_dim: size of hidden state.
Keyword Arguments:
init_method: method to initialize weights. init_method: method to initialize weights.
params_dtype
use_cpu_initialization
perform_initialization
""" """
def __init__(self, num_embeddings, embedding_dim, def __init__(self, num_embeddings: int, embedding_dim: int, *,
init_method=init.xavier_normal_): init_method=init.xavier_normal_,
params_dtype: torch.dtype=torch.float32,
use_cpu_initialization: bool=False,
perform_initialization: bool=True):
super(VocabParallelEmbedding, self).__init__() super(VocabParallelEmbedding, self).__init__()
# Keep the input dimensions. # Keep the input dimensions.
self.num_embeddings = num_embeddings self.num_embeddings = num_embeddings
...@@ -149,20 +170,20 @@ class VocabParallelEmbedding(torch.nn.Module): ...@@ -149,20 +170,20 @@ class VocabParallelEmbedding(torch.nn.Module):
self.vocab_start_index self.vocab_start_index
# Allocate weights and initialize. # Allocate weights and initialize.
args = get_args() if use_cpu_initialization:
if args.use_cpu_initialization:
self.weight = Parameter(torch.empty( self.weight = Parameter(torch.empty(
self.num_embeddings_per_partition, self.embedding_dim, self.num_embeddings_per_partition, self.embedding_dim,
dtype=args.params_dtype)) dtype=params_dtype))
if args.perform_initialization: if perform_initialization:
_initialize_affine_weight_cpu( _initialize_affine_weight_cpu(
self.weight, self.num_embeddings, self.embedding_dim, self.weight, self.num_embeddings, self.embedding_dim,
self.num_embeddings_per_partition, 0, init_method) self.num_embeddings_per_partition, 0, init_method,
params_dtype=params_dtype)
else: else:
self.weight = Parameter(torch.empty( self.weight = Parameter(torch.empty(
self.num_embeddings_per_partition, self.embedding_dim, self.num_embeddings_per_partition, self.embedding_dim,
device=torch.cuda.current_device(), dtype=args.params_dtype)) device=torch.cuda.current_device(), dtype=params_dtype))
if args.perform_initialization: if perform_initialization:
_initialize_affine_weight_gpu(self.weight, init_method, _initialize_affine_weight_gpu(self.weight, init_method,
partition_dim=0, stride=1) partition_dim=0, stride=1)
...@@ -190,10 +211,7 @@ class VocabParallelEmbedding(torch.nn.Module): ...@@ -190,10 +211,7 @@ class VocabParallelEmbedding(torch.nn.Module):
class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function): class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
""" """See linear_with_grad_accumulation_and_async_allreduce"""
Linear layer execution with asynchronous communication and gradient accumulation
fusion in backprop.
"""
@staticmethod @staticmethod
def forward(ctx, input, weight, bias, gradient_accumulation_fusion, def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
...@@ -203,7 +221,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function): ...@@ -203,7 +221,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
ctx.gradient_accumulation_fusion = gradient_accumulation_fusion ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
ctx.async_grad_allreduce = async_grad_allreduce ctx.async_grad_allreduce = async_grad_allreduce
ctx.sequence_parallel = sequence_parallel ctx.sequence_parallel = sequence_parallel
if sequence_parallel: if sequence_parallel:
world_size = get_tensor_model_parallel_world_size() world_size = get_tensor_model_parallel_world_size()
dim_size = list(input.size()) dim_size = list(input.size())
...@@ -228,7 +246,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function): ...@@ -228,7 +246,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
def backward(ctx, grad_output): def backward(ctx, grad_output):
input, weight = ctx.saved_tensors input, weight = ctx.saved_tensors
use_bias = ctx.use_bias use_bias = ctx.use_bias
if ctx.sequence_parallel: if ctx.sequence_parallel:
world_size = get_tensor_model_parallel_world_size() world_size = get_tensor_model_parallel_world_size()
dim_size = list(input.size()) dim_size = list(input.size())
...@@ -241,9 +259,8 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function): ...@@ -241,9 +259,8 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
input, input,
group=get_tensor_model_parallel_group(), async_op=True) group=get_tensor_model_parallel_group(), async_op=True)
# Delay the start of intput gradient computation shortly (3us) to have # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# gather scheduled first and have GPU resources allocated # gather is scheduled before the input gradient computation
_ = torch.empty(1, device=grad_output.device) + 1
total_input = all_gather_buffer total_input = all_gather_buffer
else: else:
total_input = input total_input = input
...@@ -257,15 +274,14 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function): ...@@ -257,15 +274,14 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
grad_output.shape[2]) grad_output.shape[2])
total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input = total_input.view(total_input.shape[0] * total_input.shape[1],
total_input.shape[2]) total_input.shape[2])
if ctx.async_grad_allreduce: if ctx.async_grad_allreduce:
# Asynchronous all-reduce # Asynchronous all-reduce
handle = torch.distributed.all_reduce( handle = torch.distributed.all_reduce(
grad_input, group=get_tensor_model_parallel_group(), async_op=True) grad_input, group=get_tensor_model_parallel_group(), async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# all-reduce scheduled first and have GPU resources allocated # all-reduce is scheduled before the weight gradient computation
_ = torch.empty(1, device=grad_output.device) + 1
if ctx.sequence_parallel: if ctx.sequence_parallel:
assert not ctx.async_grad_allreduce assert not ctx.async_grad_allreduce
dim_size = list(input.size()) dim_size = list(input.size())
...@@ -273,17 +289,20 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function): ...@@ -273,17 +289,20 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
device=torch.cuda.current_device(), device=torch.cuda.current_device(),
requires_grad=False) requires_grad=False)
# reduce_scatter # reduce_scatter
handle = torch.distributed._reduce_scatter_base(sub_grad_input, grad_input, handle = torch.distributed._reduce_scatter_base(sub_grad_input, grad_input,
group=get_tensor_model_parallel_group(), group=get_tensor_model_parallel_group(),
async_op=True) async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# reduce scatter scheduled first and have GPU resources allocated # reduce scatter is scheduled before the weight gradient computation
_ = torch.empty(1, device=grad_output.device) + 1
if ctx.gradient_accumulation_fusion: if ctx.gradient_accumulation_fusion:
import fused_dense_cuda if weight.main_grad.dtype == torch.float32:
fused_dense_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, weight.main_grad) fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, weight.main_grad)
elif weight.main_grad.dtype == torch.float16:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, weight.main_grad)
else:
raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
grad_weight = None grad_weight = None
else: else:
grad_weight = grad_output.t().matmul(total_input) grad_weight = grad_output.t().matmul(total_input)
...@@ -298,6 +317,94 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function): ...@@ -298,6 +317,94 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
return grad_input, grad_weight, grad_bias, None, None, None return grad_input, grad_weight, grad_bias, None, None, None
def linear_with_grad_accumulation_and_async_allreduce(
input: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor],
gradient_accumulation_fusion: bool,
async_grad_allreduce: bool,
sequence_parallel_enabled: bool,
) -> torch.Tensor:
"""Linear layer execution with asynchronous communication and
gradient accumulation fusion in backprop.
This has the option to accumulate the result of backprop
calculation into an existing gradient buffer, preventing the need
to do an additional addition kernel after the gradient
calculation.
Additionally, the tensor parallel all reduce of the input
gradients can be done asynchronously with the calculation of
the weight gradients.
In the case of sequence parallelism, the reduce scatter of the
input gradients is done asynchronously with the calcluation of the
weight gradients.
Use of this module requires that the environment variable
CUDA_DEVICE_MAX_CONNECTIONS=1. There are a few collective
operations, noted in the code, that should be scheduled before
compute kernels to overlap the communication with the computation,
which is necessary for a speedup but not for correctness so that
ordering isn't imposed by the scheduler. Setting
CUDA_DEVICE_MAX_CONNECTIONS=1 forces the kernels to be scheduled
in the order they are called.
Arguments:
input (torch.Tensor required): input like torch.nn.functional.linear
weight (torch.Tensor required): weight like torch.nn.functional.linear
bias (torch.Tensor optional): bias like torch.nn.functional.linear
gradient_accumulation_fusion (bool required): Perform the gradient
accumulation fusion, requires the custom CUDA extension
fused_weight_gradient_mlp_cuda module. To use
gradient_accumulation_fusion you must install APEX with
--cpp_ext and --cuda_ext. For example: "pip install
--global-option=\"--cpp_ext\" --global-option=\"--cuda_ext .\"
" Note that the extension requires CUDA>=11. Otherwise, you
must turn off gradient accumulation fusion."
async_grad_allreduce (bool required): Do the allreduce of input
gradients asyncronously with the computation of weight
gradients. If sequence_parallel_enabled is True, this must be
False, as no all reduce is performed.
sequence_parallel_enabled (bool required): Indicates that sequence
parallelism is used and thus in the forward pass the input is
all gathered, and the backward pass the input gradients are
reduce scattered.
"""
args = [
input,
weight,
bias,
gradient_accumulation_fusion,
async_grad_allreduce,
sequence_parallel_enabled,
]
if not linear_with_grad_accumulation_and_async_allreduce.warned:
if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
if sequence_parallel_enabled:
warnings.warn(
"When using sequence parallelism it is recommended to set the "
"environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
"maximum speedup")
linear_with_grad_accumulation_and_async_allreduce.warned = True
if async_grad_allreduce:
warnings.warn(
"When using async grad allreduce it is recommended to set the "
"environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
"maximum speedup")
linear_with_grad_accumulation_and_async_allreduce.warned = True
with torch.cuda.amp.autocast(enabled=False):
return LinearWithGradAccumulationAndAsyncCommunication.apply(*args)
linear_with_grad_accumulation_and_async_allreduce.warned = False
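A call sketch for the wrapper defined above (the tensors are placeholders, and sequence parallelism is left off for brevity):

output = linear_with_grad_accumulation_and_async_allreduce(
    input=input_parallel,           # [sequence, batch, hidden_in] on this rank
    weight=layer_weight,            # [output_per_partition, hidden_in]
    bias=None,
    gradient_accumulation_fusion=False,
    async_grad_allreduce=True,
    sequence_parallel_enabled=False,
)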
class ColumnParallelLinear(torch.nn.Module): class ColumnParallelLinear(torch.nn.Module):
"""Linear layer with column parallelism. """Linear layer with column parallelism.
...@@ -308,6 +415,8 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -308,6 +415,8 @@ class ColumnParallelLinear(torch.nn.Module):
Arguments: Arguments:
input_size: first dimension of matrix A. input_size: first dimension of matrix A.
output_size: second dimension of matrix A. output_size: second dimension of matrix A.
Keyword Arguments
bias: If true, add bias bias: If true, add bias
gather_output: If true, call all-gather on output and make Y available gather_output: If true, call all-gather on output and make Y available
to all GPUs, otherwise, every GPU will have its output to all GPUs, otherwise, every GPU will have its output
...@@ -321,12 +430,25 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -321,12 +430,25 @@ class ColumnParallelLinear(torch.nn.Module):
skip_bias_add: This was added to enable performance optimations where bias skip_bias_add: This was added to enable performance optimations where bias
can be fused with other elementwise operations. we skip can be fused with other elementwise operations. we skip
adding bias but instead return it. adding bias but instead return it.
async_tensor_model_parallel_allreduce:
params_dtype:
use_cpu_initialization:
gradient_accumulation_fusion:
sequence_parallel_enabled:
""" """
def __init__(self, input_size, output_size, bias=True, gather_output=True, def __init__(self, input_size, output_size, *,
bias=True, gather_output=True,
init_method=init.xavier_normal_, stride=1, init_method=init.xavier_normal_, stride=1,
keep_master_weight_for_test=False, keep_master_weight_for_test=False,
skip_bias_add=False): skip_bias_add=False,
async_tensor_model_parallel_allreduce=True,
params_dtype=torch.float32,
use_cpu_initialization=False,
perform_initialization=True,
gradient_accumulation_fusion=False,
sequence_parallel_enabled: bool = False,
):
super(ColumnParallelLinear, self).__init__() super(ColumnParallelLinear, self).__init__()
# Keep input parameters # Keep input parameters
...@@ -342,12 +464,11 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -342,12 +464,11 @@ class ColumnParallelLinear(torch.nn.Module):
# Note: torch.nn.functional.linear performs XA^T + b and as a result # Note: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose. # we allocate the transpose.
# Initialize weight. # Initialize weight.
args = get_args() if use_cpu_initialization:
if args.use_cpu_initialization:
self.weight = Parameter(torch.empty(self.output_size_per_partition, self.weight = Parameter(torch.empty(self.output_size_per_partition,
self.input_size, self.input_size,
dtype=args.params_dtype)) dtype=params_dtype))
if args.perform_initialization: if perform_initialization:
self.master_weight = _initialize_affine_weight_cpu( self.master_weight = _initialize_affine_weight_cpu(
self.weight, self.output_size, self.input_size, self.weight, self.output_size, self.input_size,
self.output_size_per_partition, 0, init_method, self.output_size_per_partition, 0, init_method,
...@@ -355,51 +476,88 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -355,51 +476,88 @@ class ColumnParallelLinear(torch.nn.Module):
else: else:
self.weight = Parameter(torch.empty( self.weight = Parameter(torch.empty(
self.output_size_per_partition, self.input_size, self.output_size_per_partition, self.input_size,
device=torch.cuda.current_device(), dtype=args.params_dtype)) device=torch.cuda.current_device(), dtype=params_dtype))
if args.perform_initialization: if perform_initialization:
_initialize_affine_weight_gpu(self.weight, init_method, _initialize_affine_weight_gpu(self.weight, init_method,
partition_dim=0, stride=stride) partition_dim=0, stride=stride)
if bias: if bias:
if args.use_cpu_initialization: if use_cpu_initialization:
self.bias = Parameter(torch.empty( self.bias = Parameter(torch.empty(
self.output_size_per_partition, dtype=args.params_dtype)) self.output_size_per_partition, dtype=params_dtype))
else: else:
self.bias = Parameter(torch.empty( self.bias = Parameter(torch.empty(
self.output_size_per_partition, self.output_size_per_partition,
device=torch.cuda.current_device(), device=torch.cuda.current_device(),
dtype=args.params_dtype)) dtype=params_dtype))
set_tensor_model_parallel_attributes(self.bias, True, 0, stride) set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
# Always initialize bias to zero. # Always initialize bias to zero.
with torch.no_grad(): with torch.no_grad():
self.bias.zero_() self.bias.zero_()
else: else:
self.register_parameter('bias', None) self.register_parameter('bias', None)
self.async_tensor_model_parallel_allreduce = ( self.async_tensor_model_parallel_allreduce = (
args.async_tensor_model_parallel_allreduce and async_tensor_model_parallel_allreduce and
world_size > 1) world_size > 1)
self.sequence_parallel = ( if sequence_parallel_enabled:
args.sequence_parallel and if world_size <= 1:
world_size > 1) warnings.warn(
assert not self.async_tensor_model_parallel_allreduce or \ f"`sequence_parallel_enabled` is set to `True`, but tensor model parallel size is {world_size}. "
not self.sequence_parallel f"Disabling sequence parallel."
self.gradient_accumulation_fusion = args.gradient_accumulation_fusion )
sequence_parallel_enabled = False
self.sequence_parallel_enabled = sequence_parallel_enabled
if gradient_accumulation_fusion:
if not _grad_accum_fusion_available:
raise RuntimeError(
"ColumnParallelLinear was called with gradient_accumulation_fusion set "
"to True but the custom CUDA extension fused_weight_gradient_mlp_cuda "
"module is not found. To use gradient_accumulation_fusion you must "
"install APEX with --cpp_ext and --cuda_ext. For example: "
"pip install --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext .\" "
"Note that the extension requires CUDA>=11. Otherwise, you must turn off "
"gradient accumulation fusion."
)
self.gradient_accumulation_fusion = gradient_accumulation_fusion
if self.async_tensor_model_parallel_allreduce and self.sequence_parallel_enabled:
raise RuntimeError(
"`async_tensor_model_parallel_allreduce` and `sequence_parallel_enabled` "
"cannot be enabled at the same time."
)
def forward(self, input_): def forward(self, input_):
"""Forward of ColumnParallelLinear
Args:
input_: 3D tensor whose order of dimension is [sequence, batch, hidden]
Returns:
- output
- bias
"""
bias = self.bias if not self.skip_bias_add else None bias = self.bias if not self.skip_bias_add else None
if self.async_tensor_model_parallel_allreduce or \ if self.async_tensor_model_parallel_allreduce or \
self.sequence_parallel: self.sequence_parallel_enabled:
input_parallel = input_ input_parallel = input_
else: else:
input_parallel = copy_to_tensor_model_parallel_region(input_) input_parallel = copy_to_tensor_model_parallel_region(input_)
# Matrix multiply. # Matrix multiply.
output_parallel = LinearWithGradAccumulationAndAsyncCommunication.apply( output_parallel = linear_with_grad_accumulation_and_async_allreduce(
input_parallel, self.weight, bias, self.gradient_accumulation_fusion, input=input_parallel,
self.async_tensor_model_parallel_allreduce, self.sequence_parallel) weight=self.weight,
bias=bias,
gradient_accumulation_fusion=self.gradient_accumulation_fusion,
async_grad_allreduce=self.async_tensor_model_parallel_allreduce,
sequence_parallel_enabled=self.sequence_parallel_enabled,
)
if self.gather_output: if self.gather_output:
# All-gather across the partitions. # All-gather across the partitions.
assert not self.sequence_parallel assert not self.sequence_parallel_enabled
output = gather_from_tensor_model_parallel_region(output_parallel) output = gather_from_tensor_model_parallel_region(output_parallel)
else: else:
output = output_parallel output = output_parallel
...@@ -422,6 +580,8 @@ class RowParallelLinear(torch.nn.Module): ...@@ -422,6 +580,8 @@ class RowParallelLinear(torch.nn.Module):
Arguments: Arguments:
input_size: first dimension of matrix A. input_size: first dimension of matrix A.
output_size: second dimension of matrix A. output_size: second dimension of matrix A.
Keyword Arguments:
bias: If true, add bias. Note that bias is not parallelized. bias: If true, add bias. Note that bias is not parallelized.
input_is_parallel: If true, we assume that the input is already input_is_parallel: If true, we assume that the input is already
split across the GPUs and we do not split split across the GPUs and we do not split
...@@ -435,13 +595,24 @@ class RowParallelLinear(torch.nn.Module): ...@@ -435,13 +595,24 @@ class RowParallelLinear(torch.nn.Module):
skip_bias_add: This was added to enable performance optimization where bias skip_bias_add: This was added to enable performance optimization where bias
can be fused with other elementwise operations. We skip can be fused with other elementwise operations. We skip
adding bias but instead return it. adding bias but instead return it.
params_dtype:
use_cpu_initialization:
perform_initialization:
gradient_accumulation_fusion:
sequence_parallel_enabled:
""" """
    def __init__(self, input_size, output_size, *,
                 bias=True, input_is_parallel=False,
                 init_method=init.xavier_normal_, stride=1,
                 keep_master_weight_for_test=False,
                 skip_bias_add=False,
                 params_dtype=torch.float32,
                 use_cpu_initialization=False,
                 perform_initialization=True,
                 gradient_accumulation_fusion=False,
                 sequence_parallel_enabled: bool = False,
                 ):
        super(RowParallelLinear, self).__init__()

        # Keep input parameters
...@@ -452,61 +623,78 @@ class RowParallelLinear(torch.nn.Module):
        world_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, world_size)
        self.skip_bias_add = skip_bias_add
        self.gradient_accumulation_fusion = gradient_accumulation_fusion
        self.sequence_parallel_enabled = sequence_parallel_enabled
        if self.sequence_parallel_enabled and not self.input_is_parallel:
            raise RuntimeError("To enable `sequence_parallel_enabled`, `input_is_parallel` must be `True`")

        # Parameters.
        # Note: torch.nn.functional.linear performs XA^T + b and as a result
        # we allocate the transpose.
        # Initialize weight.
        if use_cpu_initialization:
            self.weight = Parameter(torch.empty(self.output_size,
                                                self.input_size_per_partition,
                                                dtype=params_dtype))
            if perform_initialization:
                self.master_weight = _initialize_affine_weight_cpu(
                    self.weight, self.output_size, self.input_size,
                    self.input_size_per_partition, 1, init_method,
                    stride=stride, return_master_weight=keep_master_weight_for_test,
                    params_dtype=params_dtype)
        else:
            self.weight = Parameter(torch.empty(
                self.output_size, self.input_size_per_partition,
                device=torch.cuda.current_device(), dtype=params_dtype))
            if perform_initialization:
                _initialize_affine_weight_gpu(self.weight, init_method,
                                              partition_dim=1, stride=stride)
        if bias:
            if use_cpu_initialization:
                self.bias = Parameter(torch.empty(self.output_size,
                                                  dtype=params_dtype))
            else:
                self.bias = Parameter(torch.empty(
                    self.output_size, device=torch.cuda.current_device(),
                    dtype=params_dtype))
            setattr(self.bias, 'sequence_parallel', sequence_parallel_enabled)

            # Always initialize bias to zero.
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter('bias', None)
    def forward(self, input_):
"""Forward of RowParallelLinear
Args:
input_: 3D tensor whose order of dimension is [sequence, batch, hidden]
Returns:
- output
- bias
"""
        # Set up backprop all-reduce.
        if self.input_is_parallel:
            input_parallel = input_
        else:
            assert not self.sequence_parallel_enabled
            input_parallel = scatter_to_tensor_model_parallel_region(input_)
        # Matrix multiply.
        output_parallel = linear_with_grad_accumulation_and_async_allreduce(
            input=input_parallel,
            weight=self.weight,
            bias=None,
            gradient_accumulation_fusion=self.gradient_accumulation_fusion,
            async_grad_allreduce=False,
            sequence_parallel_enabled=False,
        )

        # All-reduce across all the partitions.
        if self.sequence_parallel_enabled:
            output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
        else:
            output_ = reduce_from_tensor_model_parallel_region(output_parallel)
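
Taken together, the two layers form the usual tensor-parallel MLP pattern: a column-split projection whose un-gathered output feeds a row-split projection. The sketch below is illustrative only, not code from this repository; it assumes torch.distributed and the tensor-model-parallel group are already initialized, and it assumes the keyword-only constructor arguments introduced in this merge (gather_output, input_is_parallel, init_method) and the megatron.core.tensor_parallel.layers module path.

import torch
import torch.nn.functional as F

from megatron.core.tensor_parallel.layers import (
    ColumnParallelLinear,
    RowParallelLinear,
)


class ParallelMLPSketch(torch.nn.Module):
    """Illustrative pairing: [s, b, h] -> column-split [s, b, 4h/p] -> row-split [s, b, h]."""

    def __init__(self, hidden_size, init_method=torch.nn.init.xavier_normal_):
        super().__init__()
        # Output dimension is partitioned across the tensor-parallel group;
        # gather_output=False keeps the activation split for the next layer.
        self.dense_h_to_4h = ColumnParallelLinear(
            hidden_size, 4 * hidden_size,
            gather_output=False, init_method=init_method)
        # Input dimension is partitioned; input_is_parallel=True consumes the
        # split activation and reduces the partial products across ranks.
        self.dense_4h_to_h = RowParallelLinear(
            4 * hidden_size, hidden_size,
            input_is_parallel=True, init_method=init_method)

    def forward(self, hidden_states):
        # Both layers return (output, bias); bias is None unless skip_bias_add=True.
        intermediate, _ = self.dense_h_to_4h(hidden_states)
        intermediate = F.gelu(intermediate)
        output, output_bias = self.dense_4h_to_h(intermediate)
        return output, output_bias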
......

...@@ -2,7 +2,11 @@

import torch

from megatron.core.parallel_state import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    get_tensor_model_parallel_group,
)

from .utils import split_tensor_along_last_dim
......

# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch
...@@ -11,13 +10,19 @@ from torch import _C
from torch.cuda import _lazy_call, device as device_ctx_manager
from torch.utils.checkpoint import detach_variable

from megatron.core.parallel_state import (
    get_data_parallel_rank,
    get_tensor_model_parallel_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)

from .utils import (
    split_tensor_into_1d_equal_chunks,
    gather_split_1d_tensor,
)

from megatron.core.utils import safely_set_viewless_tensor_data


# Default name for the model parallel rng tracker.
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'

...@@ -56,117 +61,6 @@ def _set_cuda_rng_state(new_state, device=-1):
        _lazy_call(cb)
def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
"""Break a tensor into equal 1D chunks."""
partition_size = torch.numel(tensor) // \
get_tensor_model_parallel_world_size()
start_index = partition_size * get_tensor_model_parallel_rank()
end_index = start_index + partition_size
if new_buffer:
data = torch.empty(partition_size, dtype=tensor.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
data.copy_(tensor.view(-1)[start_index:end_index])
else:
data = tensor.view(-1)[start_index:end_index]
return data
def gather_split_1d_tensor(tensor):
"""Opposite of above function, gather values from model parallel ranks."""
numel_gathered = torch.numel(tensor) * \
get_tensor_model_parallel_world_size()
gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
# TODO: This API is experimental in pytorch (as of Feb 2022) and
# this might break in future pytorch releases. We chose this API
# as opposed to torch.distributed.all_gather for efficiency reasons.
# This API calls directly NCCL all-gather versus the former does
# internal copies and can potentially cause slow down.
torch.distributed._all_gather_base(gathered, tensor,
group=get_tensor_model_parallel_group())
return gathered
def _kernel_make_viewless_tensor(inp, requires_grad):
'''Make a viewless tensor.
View tensors have the undesirable side-affect of retaining a reference
to the originally-viewed tensor, even after manually setting the '.data'
field. This method creates a new tensor that links to the old tensor's
data, without linking the viewed tensor, referenced via the '._base'
field.
'''
out = torch.empty(
(1,),
dtype = inp.dtype,
device = inp.device,
requires_grad = requires_grad,
)
out.data = inp.data
return out
class MakeViewlessTensor(torch.autograd.Function):
'''
Autograd function to make a viewless tensor.
This function should be used in cases where the computation graph needs
to be propagated, but we only want a viewless tensor (e.g.,
ParallelTransformer's hidden_states). Call this function by passing
'keep_graph = True' to 'make_viewless_tensor()'.
'''
@staticmethod
def forward(ctx, inp, requires_grad):
return _kernel_make_viewless_tensor(inp, requires_grad)
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
def make_viewless_tensor(inp, requires_grad, keep_graph):
'''
Entry-point for creating viewless tensors.
This method should be used, rather than calling 'MakeViewlessTensor'
or '_kernel_make_viewless_tensor' directly. This method acts as a
switch for determining if an autograd function or a regular method
should be used to create the tensor.
'''
# return tensor as-is, if not a 'view'
if inp._base is None:
return inp
# create viewless tensor
if keep_graph:
return MakeViewlessTensor.apply(inp, requires_grad)
else:
return _kernel_make_viewless_tensor(inp, requires_grad)
def assert_viewless_tensor(tensor, extra_msg = None):
'''Assert that a tensor is not a view (i.e., its '._base' field is
not set).'''
if isinstance(tensor, list):
[ assert_viewless_tensor(t) for t in tensor ]
return tensor
if not isinstance(tensor, torch.Tensor):
return tensor
assert tensor._base is None, (
"Ensure tensor._base is None before setting tensor.data or storing "
"tensor to memory buffer. Otherwise, a memory leak will occur (and "
"likely accumulate over iterations). %s"
) % extra_msg
return tensor
def safely_set_viewless_tensor_data(tensor, new_data_tensor):
'''Safely set tensor's '.data' field.
Check first that the tensor is viewless (i.e., '._base' not set). If not,
raise an exception.
'''
assert_viewless_tensor(tensor, extra_msg = "FYI, tensor._base has shape %s, and new_data_tensor has shape %s." % ("--" if tensor._base is None else tensor._base.shape, new_data_tensor.shape))
tensor.data = new_data_tensor
class CudaRNGStatesTracker:
    """Tracker for the cuda RNG states.
...@@ -271,13 +165,6 @@ def model_parallel_cuda_manual_seed(seed):
    # Data parallel gets the original seed.
    data_parallel_seed = seed
if torch.distributed.get_rank() == 0:
print('> initializing model parallel cuda seeds on global rank {}, '
'model parallel rank {}, and data parallel rank {} with '
'model parallel seed: {} and data parallel seed: {}'.format(
torch.distributed.get_rank(), get_tensor_model_parallel_rank(),
get_data_parallel_rank(), tensor_model_parallel_seed,
data_parallel_seed), flush=True)
    _CUDA_RNG_STATE_TRACKER.reset()
    # Set the default state.
    torch.cuda.manual_seed(data_parallel_seed)
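
For readers unfamiliar with the RNG tracker, the core mechanism it relies on is forking the CUDA RNG state: save the current state, swap in a separate one for a region (e.g. tensor-parallel dropout), then restore. The snippet below is a simplified stand-in, not the CudaRNGStatesTracker API; the real tracker keeps a named dictionary of states and stores the advanced state back on exit, so each tracked stream progresses across calls rather than being reseeded.

import contextlib
import torch


@contextlib.contextmanager
def fork_cuda_rng(seed):
    """Save the current CUDA RNG state, run the region under `seed`,
    then restore the original state so code outside is unaffected."""
    orig_state = torch.cuda.get_rng_state()
    torch.cuda.manual_seed(seed)
    try:
        yield
    finally:
        torch.cuda.set_rng_state(orig_state)


if torch.cuda.is_available():
    with fork_cuda_rng(1234):
        # Draws from the seeded stream; the default stream resumes afterwards.
        dropout_mask = (torch.rand(4, device='cuda') > 0.1)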
......
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import torch
from typing import List, Sequence
from megatron.core.utils import divide
from megatron.core import parallel_state
def split_tensor_along_last_dim(
tensor: torch.Tensor,
num_partitions: int,
contiguous_split_chunks: bool = False,
) -> List[torch.Tensor]:
""" Split a tensor along its last dimension.
Arguments:
tensor: input tensor.
num_partitions: number of partitions to split the tensor
contiguous_split_chunks: If True, make each chunk contiguous
in memory.
Returns:
A list of Tensors
"""
# Get the size and dimension.
last_dim = tensor.dim() - 1
last_dim_size = divide(tensor.size()[last_dim], num_partitions)
# Split.
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
# Note: torch.split does not create contiguous tensors by default.
if contiguous_split_chunks:
return tuple(chunk.contiguous() for chunk in tensor_list)
return tensor_list
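
As a quick sanity check of the behaviour documented above, the snippet below exercises the function on CPU; it only relies on torch.split, so no parallel groups are needed (the import path assumes the megatron/core/tensor_parallel/utils.py layout introduced in this merge).

import torch

from megatron.core.tensor_parallel.utils import split_tensor_along_last_dim

x = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)   # last dim = 4

# Two views of size [2, 3, 2] along the last dimension.
chunks = split_tensor_along_last_dim(x, num_partitions=2)
assert len(chunks) == 2 and chunks[0].shape == (2, 3, 2)

# torch.split returns non-contiguous views here; request copies when a
# downstream kernel needs contiguous memory.
contig = split_tensor_along_last_dim(x, 2, contiguous_split_chunks=True)
assert all(chunk.is_contiguous() for chunk in contig)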
def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
""" Break a tensor into equal 1D chunks across tensor parallel ranks.
Returns a Tensor or View with this rank's portion of the data.
Arguments:
tensor: The tensor to split
Keyword Arguments:
new_buffer (bool): If True, returns a new Tensor.
If False, returns a view into the existing Tensor.
Default is False
"""
partition_size = torch.numel(tensor) // \
parallel_state.get_tensor_model_parallel_world_size()
start_index = partition_size * parallel_state.get_tensor_model_parallel_rank()
end_index = start_index + partition_size
if new_buffer:
data = torch.empty(partition_size, dtype=tensor.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
data.copy_(tensor.view(-1)[start_index:end_index])
else:
data = tensor.view(-1)[start_index:end_index]
return data
def gather_split_1d_tensor(tensor):
""" Opposite of split_tensor_into_1d_equal_chunks. Gather values from tensor
model parallel ranks.
Returns a new Tensor with the gathered data.
Arguments:
tensor: A Tensor or view of this rank's portion of the data.
"""
numel_gathered = torch.numel(tensor) * \
parallel_state.get_tensor_model_parallel_world_size()
gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
# TODO: This API is experimental in pytorch (as of Feb 2022) and
# this might break in future pytorch releases. We chose this API
# as opposed to torch.distributed.all_gather for efficiency reasons.
# This API calls directly NCCL all-gather versus the former does
# internal copies and can potentially cause slow down.
torch.distributed._all_gather_base(gathered, tensor,
group=parallel_state.get_tensor_model_parallel_group())
return gathered
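
The two functions above only make sense inside an initialized tensor-parallel group, so the sketch below reproduces just the index arithmetic with a pretend rank and world size (hypothetical values, plain torch, no distributed setup).

import torch

world_size, rank = 4, 1
tensor = torch.arange(32, dtype=torch.float32).reshape(4, 8)

partition_size = tensor.numel() // world_size           # 8 elements per rank
start = partition_size * rank
chunk = tensor.view(-1)[start:start + partition_size]   # rank 1's view

# gather_split_1d_tensor would all-gather these per-rank chunks back into a
# flat buffer of tensor.numel() elements, in rank order.
assert chunk.numel() == partition_size and chunk[0] == 8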
class VocabUtility:
""" Split the vocabulary into `world_size` chunks and return the first
and last index of the vocabulary belonging to the `rank`
partition: Note that indices in [fist, last)
"""
@staticmethod
def vocab_range_from_per_partition_vocab_size(
per_partition_vocab_size: int, rank, world_size: int
) -> Sequence[int]:
index_f = rank * per_partition_vocab_size
index_l = index_f + per_partition_vocab_size
return index_f, index_l
@staticmethod
def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int, world_size: int) -> Sequence[int]:
per_partition_vocab_size = divide(global_vocab_size, world_size)
return VocabUtility.vocab_range_from_per_partition_vocab_size(
per_partition_vocab_size, rank, world_size
)
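
A worked example of the arithmetic (hypothetical numbers): a 50,304-token vocabulary split over a tensor-parallel world size of 8 gives 6,288 entries per rank, so rank 2 owns the half-open range [12576, 18864).

from megatron.core.tensor_parallel.utils import VocabUtility

first, last = VocabUtility.vocab_range_from_global_vocab_size(
    global_vocab_size=50304, rank=2, world_size=8)
assert (first, last) == (12576, 18864)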
"""Utility functions used throughout Megatron core"""
from functools import reduce
import operator
import torch
from megatron.core import parallel_state
def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
assert numerator % denominator == 0, "{} is not divisible by {}".format(
numerator, denominator
)
def divide(numerator, denominator):
"""Ensure that numerator is divisible by the denominator and return
the division value."""
ensure_divisibility(numerator, denominator)
return numerator // denominator
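
For example (assuming the megatron/core/utils.py module path used in this merge):

from megatron.core.utils import divide

assert divide(50304, 8) == 6288
# divide(10, 3) would raise AssertionError: "10 is not divisible by 3"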
class GlobalMemoryBuffer:
"""Global buffer to avoid dynamic memory allocations.
Caller should ensure that buffers of the same name
are not used concurrently."""
def __init__(self):
self.buffer = {}
def get_tensor(self, tensor_shape, dtype, name):
required_len = reduce(operator.mul, tensor_shape, 1)
if self.buffer.get((name, dtype), None) is None or \
self.buffer[(name, dtype)].numel() < required_len:
self.buffer[(name, dtype)] = \
torch.empty(required_len,
dtype=dtype,
device=torch.cuda.current_device(),
requires_grad=False)
return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape)
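
A brief usage sketch (requires a CUDA device, since the buffer is allocated on torch.cuda.current_device(); the buffer name "demo" is arbitrary). Repeated requests under the same (name, dtype) reuse the allocation as long as they do not need more elements.

import torch

from megatron.core.utils import GlobalMemoryBuffer

buf = GlobalMemoryBuffer()
a = buf.get_tensor((4, 1024), torch.float32, "demo")   # allocates 4096 floats
b = buf.get_tensor((2, 512), torch.float32, "demo")    # reuses the same storage
assert a.data_ptr() == b.data_ptr()

# Per the docstring, "demo" must not be used by two concurrent callers,
# since both views alias the same underlying buffer.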
def _kernel_make_viewless_tensor(inp, requires_grad):
'''Make a viewless tensor.
    View tensors have the undesirable side-effect of retaining a reference
to the originally-viewed tensor, even after manually setting the '.data'
field. This method creates a new tensor that links to the old tensor's
data, without linking the viewed tensor, referenced via the '._base'
field.
'''
out = torch.empty(
(1,),
dtype = inp.dtype,
device = inp.device,
requires_grad = requires_grad,
)
out.data = inp.data
return out
class MakeViewlessTensor(torch.autograd.Function):
'''
Autograd function to make a viewless tensor.
This function should be used in cases where the computation graph needs
to be propagated, but we only want a viewless tensor (e.g.,
ParallelTransformer's hidden_states). Call this function by passing
'keep_graph = True' to 'make_viewless_tensor()'.
'''
@staticmethod
def forward(ctx, inp, requires_grad):
return _kernel_make_viewless_tensor(inp, requires_grad)
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
def make_viewless_tensor(inp, requires_grad, keep_graph):
'''
Entry-point for creating viewless tensors.
This method should be used, rather than calling 'MakeViewlessTensor'
or '_kernel_make_viewless_tensor' directly. This method acts as a
switch for determining if an autograd function or a regular method
should be used to create the tensor.
'''
# return tensor as-is, if not a 'view'
if inp._base is None:
return inp
# create viewless tensor
if keep_graph:
return MakeViewlessTensor.apply(inp, requires_grad)
else:
return _kernel_make_viewless_tensor(inp, requires_grad)
def assert_viewless_tensor(tensor, extra_msg = None):
'''Assert that a tensor is not a view (i.e., its '._base' field is
not set).'''
if isinstance(tensor, list):
[ assert_viewless_tensor(t) for t in tensor ]
return tensor
if not isinstance(tensor, torch.Tensor):
return tensor
assert tensor._base is None, (
"Ensure tensor._base is None before setting tensor.data or storing "
"tensor to memory buffer. Otherwise, a memory leak will occur (and "
"likely accumulate over iterations). %s"
) % extra_msg
return tensor
def safely_set_viewless_tensor_data(tensor, new_data_tensor):
'''Safely set tensor's '.data' field.
Check first that the tensor is viewless (i.e., '._base' not set). If not,
raise an exception.
'''
assert_viewless_tensor(tensor, extra_msg = "FYI, tensor._base has shape %s, and new_data_tensor has shape %s." % ("--" if tensor._base is None else tensor._base.shape, new_data_tensor.shape))
tensor.data = new_data_tensor
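
To see what the viewless helpers guard against, the sketch below creates a view, strips its ._base link, and then swaps .data safely; keep_graph=False is fine here since nothing requires grad (import path assumes megatron/core/utils.py).

import torch

from megatron.core.utils import (
    make_viewless_tensor,
    safely_set_viewless_tensor_data,
)

base = torch.zeros(4, 4)
view = base.view(16)                  # a view keeps a reference via ._base
assert view._base is base

# Break the link so `base` is no longer pinned through the view.
viewless = make_viewless_tensor(view, requires_grad=False, keep_graph=False)
assert viewless._base is None

# Now .data can be replaced without tripping the leak check above.
safely_set_viewless_tensor_data(viewless, torch.ones(16))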
......

...@@ -4,7 +4,8 @@ import time
import numpy as np
import torch

from megatron import get_args, get_tokenizer, print_rank_0
from megatron.core import mpu, tensor_parallel
from megatron.data.dataset_utils import create_masked_lm_predictions, \
    pad_and_convert_to_numpy
from megatron.data.data_samplers import MegatronPretrainingSampler
...@@ -57,7 +58,7 @@ def get_ict_batch(data_iterator):
        data = None
    else:
        data = next(data_iterator)
    data_b = tensor_parallel.broadcast_data(keys, data, datatype)

    # Unpack.
    query_tokens = data_b['query_tokens'].long()
......

...@@ -8,8 +8,6 @@ import numpy as np
import torch

from megatron import print_rank_0
from megatron import mpu
class BlendableDataset(torch.utils.data.Dataset):
......

...@@ -8,7 +8,7 @@ import torch
import numpy as np
from torch.utils.data import Dataset

from megatron import get_args
from megatron.core import mpu


def build_pretraining_data_loader(dataset, consumed_samples):
......

...@@ -28,9 +28,9 @@ import torch

from megatron import (
    get_args,
mpu,
    print_rank_0
)
from megatron.core import mpu

from megatron.data.blendable_dataset import BlendableDataset
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
......