Commit 42c40715 authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'core' into 'main'

Move most of mpu functionality into a new "Megatron core"

See merge request ADLR/megatron-lm!462
parents fdc801ed d3a416cd
[html]
directory = coverage
[run]
data_file = .coverage_$LOCAL_RANK
__pycache__
*.so
build
.coverage_*
*.egg-info
image: gitlab-master.nvidia.com/dl/dgx/pytorch:21.12-py3-devel
test:
tags:
- docker_gpu_enabled
script:
- pytest --junitxml=report.xml tests
- torchrun --nproc_per_node=8 -m pytest --cov-report=term --cov-report=html --cov=megatron/core tests/
coverage: '/(?i)total.*? (100(?:\.0+)?\%|[1-9]?\d(?:\.\d+)?\%)$/'
artifacts:
when: always
reports:
junit: report.xml
paths:
- coverage
expire_in: 30 days
\ No newline at end of file
......@@ -10,7 +10,6 @@ from .global_vars import get_tokenizer
from .global_vars import get_tensorboard_writer
from .global_vars import get_adlr_autoresume
from .global_vars import get_timers
from .global_vars import get_global_memory_buffer
from .initialize import initialize_megatron
from .utils import (print_rank_0,
......
......@@ -168,14 +168,6 @@ def validate_args(args, defaults={}):
if args.accumulate_allreduce_grads_in_fp32:
assert args.DDP_impl == 'local'
assert args.use_contiguous_buffers_in_local_ddp
else:
if args.gradient_accumulation_fusion:
args.gradient_accumulation_fusion = False
if args.rank == 0:
print('Gradient accumulation fusion to linear layer weight '
'gradient computation is supported only with fp32 '
'gradient accumulation. Setting gradient_accumulation_fusion '
'to False', flush=True)
# If we use the distributed optimizer, we need to have local DDP
# and we should make sure use-contiguous-buffers-in-local-ddp is on.
......@@ -321,6 +313,18 @@ def validate_args(args, defaults={}):
if args.sequence_parallel:
args.async_tensor_model_parallel_allreduce = False
if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
if args.sequence_parallel:
raise RuntimeError(
"Using sequence parallelism requires setting the environment variable "
"CUDA_DEVICE_MAX_CONNECTIONS to 1")
if args.async_tensor_model_parallel_allreduce:
raise RuntimeError(
"Using async gradient all reduce requires setting the environment "
"variable CUDA_DEVICE_MAX_CONNECTIONS to 1")
_print_args(args)
return args
......
......@@ -9,8 +9,8 @@ import numpy as np
import torch
from megatron import (mpu,
update_num_microbatches)
from megatron import update_num_microbatches
from megatron.core import mpu, tensor_parallel
from .global_vars import get_args
from .utils import (unwrap_model,
print_rank_0)
......@@ -185,7 +185,7 @@ def get_rng_state():
'np_rng_state': np.random.get_state(),
'torch_rng_state': torch.get_rng_state(),
'cuda_rng_state': torch.cuda.get_rng_state(),
'rng_tracker_states': mpu.get_cuda_rng_tracker().get_states()}
'rng_tracker_states': tensor_parallel.get_cuda_rng_tracker().get_states()}
rng_state_list = None
if torch.distributed.is_initialized() and \
......@@ -590,7 +590,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
# Check for empty states array
if not rng_state['rng_tracker_states']:
raise KeyError
mpu.get_cuda_rng_tracker().set_states(
tensor_parallel.get_cuda_rng_tracker().set_states(
rng_state['rng_tracker_states'])
else: # backward compatability
random.setstate(model_state_dict['random_rng_state'])
......@@ -600,7 +600,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
# Check for empty states array
if not model_state_dict['rng_tracker_states']:
raise KeyError
mpu.get_cuda_rng_tracker().set_states(
tensor_parallel.get_cuda_rng_tracker().set_states(
model_state_dict['rng_tracker_states'])
except KeyError:
print_rank_0('Unable to load rng state from checkpoint {}. '
......
import megatron.core.parallel_state
import megatron.core.tensor_parallel
import megatron.core.utils
# Alias parallel_state as mpu, its legacy name
mpu = parallel_state
__all__ = [
"parallel_state",
"tensor_parallel",
"utils",
]
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Model and data parallel groups."""
import torch
from typing import Optional
from .utils import ensure_divisibility
from .utils import GlobalMemoryBuffer
# Intra-layer model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None
......@@ -45,17 +44,16 @@ _PIPELINE_GLOBAL_RANKS = None
# rank when broadcasting weights from src to all other data parallel ranks
_DATA_PARALLEL_GLOBAL_RANKS = None
# Memory buffers to avoid dynamic memory allocation
_GLOBAL_MEMORY_BUFFER = None
def is_unitialized():
"""Useful for code segments that may be accessed with or without mpu initialization"""
return _DATA_PARALLEL_GROUP is None
def initialize_model_parallel(tensor_model_parallel_size_=1,
pipeline_model_parallel_size_=1,
virtual_pipeline_model_parallel_size_=None,
pipeline_model_parallel_split_rank_=None):
def initialize_model_parallel(
tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1,
virtual_pipeline_model_parallel_size: Optional[int] = None,
pipeline_model_parallel_split_rank: Optional[int] = None,
) -> None:
"""
Initialize model data parallel groups.
......@@ -67,7 +65,6 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
pipeline_model_parallel_split_rank: for models with both encoder and decoder,
rank in pipeline with split point.
Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
the model pipeline. The present function will
......@@ -84,49 +81,48 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
with a total of 16 GPUs, rank 0 to 7 belong to the first box and
ranks 8 to 15 belong to the second box.
"""
if torch.distributed.get_rank() == 0:
print('> initializing tensor model parallel with size {}'.format(
tensor_model_parallel_size_))
print('> initializing pipeline model parallel with size {}'.format(
pipeline_model_parallel_size_))
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
world_size = torch.distributed.get_world_size()
tensor_model_parallel_size = min(tensor_model_parallel_size_, world_size)
pipeline_model_parallel_size = min(pipeline_model_parallel_size_, world_size)
ensure_divisibility(world_size,
tensor_model_parallel_size * pipeline_model_parallel_size)
data_parallel_size = world_size // (tensor_model_parallel_size *
pipeline_model_parallel_size)
num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
num_data_parallel_groups = world_size // data_parallel_size
if virtual_pipeline_model_parallel_size_ is not None:
world_size: int = torch.distributed.get_world_size()
if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0:
raise RuntimeError(
f"world_size ({world_size}) is not divisible by tensor_model_parallel_size "
f"({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size})"
)
data_parallel_size: int = world_size // (tensor_model_parallel_size *
pipeline_model_parallel_size)
num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size
num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
num_data_parallel_groups: int = world_size // data_parallel_size
if virtual_pipeline_model_parallel_size is not None:
if not pipeline_model_parallel_size > 2:
raise RuntimeError("pipeline-model-parallel size should be greater than 2 with "
"interleaved schedule")
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size_
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size
if pipeline_model_parallel_split_rank_ is not None:
if pipeline_model_parallel_split_rank is not None:
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank_
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank
rank = torch.distributed.get_rank()
# Build the data-parallel groups.
global _DATA_PARALLEL_GROUP
global _DATA_PARALLEL_GLOBAL_RANKS
assert _DATA_PARALLEL_GROUP is None, \
'data parallel group is already initialized'
assert _DATA_PARALLEL_GROUP is None, 'data parallel group is already initialized'
all_data_parallel_group_ranks = []
for i in range(pipeline_model_parallel_size):
start_rank = i * num_pipeline_model_parallel_groups
end_rank = (i + 1) * num_pipeline_model_parallel_groups
for j in range(tensor_model_parallel_size):
ranks = range(start_rank + j, end_rank,
tensor_model_parallel_size)
ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)
all_data_parallel_group_ranks.append(list(ranks))
group = torch.distributed.new_group(ranks)
if rank in ranks:
......@@ -135,8 +131,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
# Build the model-parallel groups.
global _MODEL_PARALLEL_GROUP
assert _MODEL_PARALLEL_GROUP is None, \
'model parallel group is already initialized'
assert _MODEL_PARALLEL_GROUP is None, 'model parallel group is already initialized'
for i in range(data_parallel_size):
ranks = [data_parallel_group_ranks[i]
for data_parallel_group_ranks in all_data_parallel_group_ranks]
......@@ -163,15 +158,13 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
'pipeline model parallel group is already initialized'
global _EMBEDDING_GROUP
global _EMBEDDING_GLOBAL_RANKS
assert _EMBEDDING_GROUP is None, \
'embedding group is already initialized'
assert _EMBEDDING_GROUP is None, 'embedding group is already initialized'
global _POSITION_EMBEDDING_GROUP
global _POSITION_EMBEDDING_GLOBAL_RANKS
assert _POSITION_EMBEDDING_GROUP is None, \
'position embedding group is already initialized'
for i in range(num_pipeline_model_parallel_groups):
ranks = range(i, world_size,
num_pipeline_model_parallel_groups)
ranks = range(i, world_size, num_pipeline_model_parallel_groups)
group = torch.distributed.new_group(ranks)
if rank in ranks:
_PIPELINE_MODEL_PARALLEL_GROUP = group
......@@ -181,14 +174,14 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
if len(ranks) > 1:
embedding_ranks = [ranks[0], ranks[-1]]
position_embedding_ranks = [ranks[0]]
if pipeline_model_parallel_split_rank_ is not None:
if ranks[pipeline_model_parallel_split_rank_] not in embedding_ranks:
if pipeline_model_parallel_split_rank is not None:
if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks:
embedding_ranks = [ranks[0],
ranks[pipeline_model_parallel_split_rank_],
ranks[pipeline_model_parallel_split_rank],
ranks[-1]]
if ranks[pipeline_model_parallel_split_rank_] not in position_embedding_ranks:
if ranks[pipeline_model_parallel_split_rank] not in position_embedding_ranks:
position_embedding_ranks = [ranks[0],
ranks[pipeline_model_parallel_split_rank_]]
ranks[pipeline_model_parallel_split_rank]]
else:
embedding_ranks = ranks
position_embedding_ranks = ranks
......@@ -205,6 +198,12 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
if rank in ranks:
_POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks
# Initialize global memory buffer
# This isn't really "parallel state" but there isn't another good place to
# put this. If we end up with a more generic initialization of megatron-core
# we could stick it there
_set_global_memory_buffer()
def model_parallel_is_initialized():
"""Check if model and data parallel groups are initialized."""
......@@ -297,6 +296,12 @@ def set_pipeline_model_parallel_rank(rank):
_MPU_PIPELINE_MODEL_PARALLEL_RANK = rank
def set_pipeline_model_parallel_split_rank(rank):
"""Set pipeline model parallel split rank."""
global _MPU_PIPELINE_MODEL_PARALLEL_SPLIT_RANK
_MPU_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = rank
def get_tensor_model_parallel_rank():
"""Return my rank for the tensor model parallel group."""
global _MPU_TENSOR_MODEL_PARALLEL_RANK
......@@ -313,57 +318,6 @@ def get_pipeline_model_parallel_rank():
return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
def get_num_layers(args, is_encoder_and_decoder_model, is_decoder=False):
"""Compute the number of transformer layers resident on the current rank."""
if get_pipeline_model_parallel_world_size() > 1:
if is_encoder_and_decoder_model:
assert args.pipeline_model_parallel_split_rank is not None
# When a standalone embedding stage is used, a rank is taken from
# the encoder's ranks, to be used for the encoder's embedding
# layer. This way, the rank referenced by the 'split rank' remains
# the same whether or not a standalone embedding stage is used.
num_ranks_in_encoder = (
args.pipeline_model_parallel_split_rank - 1
if args.standalone_embedding_stage else
args.pipeline_model_parallel_split_rank
)
num_ranks_in_decoder = args.transformer_pipeline_model_parallel_size - num_ranks_in_encoder
assert args.encoder_num_layers % num_ranks_in_encoder == 0, \
'encoder_num_layers (%d) must be divisible by number of ranks given to encoder (%d)' % (args.encoder_num_layers, num_ranks_in_encoder)
assert args.decoder_num_layers % num_ranks_in_decoder == 0, \
'decoder_num_layers (%d) must be divisible by number of ranks given to decoder (%d)' % (args.decoder_num_layers, num_ranks_in_decoder)
if is_pipeline_stage_before_split():
num_layers = (
0
if args.standalone_embedding_stage
and get_pipeline_model_parallel_rank() == 0 else
args.encoder_num_layers // num_ranks_in_encoder
)
else:
num_layers = args.decoder_num_layers // num_ranks_in_decoder
else:
assert args.num_layers == args.encoder_num_layers
assert args.num_layers % args.transformer_pipeline_model_parallel_size == 0, \
'num_layers must be divisible by transformer_pipeline_model_parallel_size'
# When a standalone embedding stage is used, all transformer layers
# are divided among pipeline rank >= 1, while on pipeline rank 0,
# ranks either contain the input embedding layer (virtual pp rank 0),
# or no layers at all (virtual pp rank >= 1).
num_layers = (
0
if args.standalone_embedding_stage
and get_pipeline_model_parallel_rank() == 0 else
args.num_layers // args.transformer_pipeline_model_parallel_size
)
else:
if not is_decoder:
num_layers = args.encoder_num_layers
else:
num_layers = args.decoder_num_layers
return num_layers
def is_pipeline_first_stage(ignore_virtual=False):
"""Return True if in the first pipeline model-parallel stage, False otherwise."""
......@@ -484,18 +438,23 @@ def get_data_parallel_src_rank():
def get_pipeline_model_parallel_first_rank():
"""Return the global rank of the first process in the pipeline for the
current tensor parallel group"""
assert _PIPELINE_GLOBAL_RANKS is not None, \
"Pipeline parallel group is not initialized"
return _PIPELINE_GLOBAL_RANKS[0]
def get_pipeline_model_parallel_last_rank():
"""Return the global rank of the last process in the pipeline for the
current tensor parallel group"""
assert _PIPELINE_GLOBAL_RANKS is not None, \
"Pipeline parallel group is not initialized"
last_rank_local = get_pipeline_model_parallel_world_size() - 1
return _PIPELINE_GLOBAL_RANKS[last_rank_local]
def get_pipeline_model_parallel_next_rank():
"""Return the global rank that follows the caller in the pipeline"""
assert _PIPELINE_GLOBAL_RANKS is not None, \
"Pipeline parallel group is not initialized"
rank_in_pipeline = get_pipeline_model_parallel_rank()
......@@ -504,6 +463,7 @@ def get_pipeline_model_parallel_next_rank():
def get_pipeline_model_parallel_prev_rank():
"""Return the global rank that preceeds the caller in the pipeline"""
assert _PIPELINE_GLOBAL_RANKS is not None, \
"Pipeline parallel group is not initialized"
rank_in_pipeline = get_pipeline_model_parallel_rank()
......@@ -520,6 +480,17 @@ def get_data_parallel_rank():
"""Return my rank for the data parallel group."""
return torch.distributed.get_rank(group=get_data_parallel_group())
def _set_global_memory_buffer():
"""Initialize global buffer"""
global _GLOBAL_MEMORY_BUFFER
assert _GLOBAL_MEMORY_BUFFER is None, 'global memory buffer is already initialized'
_GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer()
def get_global_memory_buffer():
"""Return the global GlobalMemoryBuffer object"""
assert _GLOBAL_MEMORY_BUFFER is not None, 'global memory buffer is not initialized'
return _GLOBAL_MEMORY_BUFFER
def destroy_model_parallel():
"""Set the groups to none."""
......@@ -535,3 +506,17 @@ def destroy_model_parallel():
_EMBEDDING_GROUP = None
global _POSITION_EMBEDDING_GROUP
_POSITION_EMBEDDING_GROUP = None
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
global _MPU_TENSOR_MODEL_PARALLEL_RANK
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
global _MPU_PIPELINE_MODEL_PARALLEL_RANK
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None
global _GLOBAL_MEMORY_BUFFER
_GLOBAL_MEMORY_BUFFER = None
from .cross_entropy import vocab_parallel_cross_entropy
from .data import broadcast_data
from .layers import (
ColumnParallelLinear,
RowParallelLinear,
VocabParallelEmbedding,
set_tensor_model_parallel_attributes,
set_defaults_if_not_set_tensor_model_parallel_attributes,
copy_tensor_model_parallel_attributes,
param_is_not_tensor_parallel_duplicate,
linear_with_grad_accumulation_and_async_allreduce
)
from .mappings import (
copy_to_tensor_model_parallel_region,
gather_from_tensor_model_parallel_region,
gather_from_sequence_parallel_region,
scatter_to_tensor_model_parallel_region,
scatter_to_sequence_parallel_region,
)
from .random import (
checkpoint,
get_cuda_rng_tracker,
model_parallel_cuda_manual_seed,
)
from .utils import (
split_tensor_along_last_dim,
split_tensor_into_1d_equal_chunks,
gather_split_1d_tensor,
)
__all__ = [
# cross_entropy.py
"vocab_parallel_cross_entropy",
# data.py
"broadcast_data",
#layers.py
"ColumnParallelLinear",
"RowParallelLinear",
"VocabParallelEmbedding",
"set_tensor_model_parallel_attributes",
"set_defaults_if_not_set_tensor_model_parallel_attributes",
"copy_tensor_model_parallel_attributes",
"param_is_not_tensor_parallel_duplicate",
"linear_with_grad_accumulation_and_async_allreduce",
# mappings.py
"copy_to_tensor_model_parallel_region",
"gather_from_tensor_model_parallel_region",
"gather_from_sequence_parallel_region",
# "reduce_from_tensor_model_parallel_region",
"scatter_to_tensor_model_parallel_region",
"scatter_to_sequence_parallel_region",
# random.py
"checkpoint",
"get_cuda_rng_tracker",
"model_parallel_cuda_manual_seed",
# utils.py
"split_tensor_along_last_dim",
"split_tensor_into_1d_equal_chunks",
"gather_split_1d_tensor",
]
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import torch
from .initialize import get_tensor_model_parallel_group
from .initialize import get_tensor_model_parallel_rank
from .initialize import get_tensor_model_parallel_world_size
from megatron.core.parallel_state import (
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size
)
from .utils import VocabUtility
class _VocabParallelCrossEntropy(torch.autograd.Function):
@staticmethod
def forward(ctx, vocab_parallel_logits, target):
def forward(ctx, vocab_parallel_logits, target, label_smoothing=0.0):
# Maximum value along vocab dimension across all GPUs.
logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
......@@ -20,7 +22,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
op=torch.distributed.ReduceOp.MAX,
group=get_tensor_model_parallel_group())
# Subtract the maximum value.
vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
vocab_parallel_logits = vocab_parallel_logits - logits_max.unsqueeze(dim=-1)
# Get the partition's vocab indecies
get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
......@@ -62,8 +64,32 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
# Loss = log(sum(exp(logits))) - predicted-logit.
loss = torch.log(sum_exp_logits) - predicted_logits
# Store softmax, target-mask and masked-target for backward pass.
# Normalize and optionally smooth logits
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
vocab_size = exp_logits.size(-1)
if label_smoothing > 0:
"""
We'd like to assign 1 / (K - 1) probability mass to every index that is not the ground truth.
= (1 - alpha) * y_gt + alpha * mean(y_{i for i != gt})
= (1 - alpha) * y_gt + (alpha / (K - 1)) * \sum_{i != gt} y_i
= ((K - 1) * (1 - alpha) / (K - 1)) * y_gt + (alpha / (K - 1)) * \sum_{i != gt} y_i
= (K * (1 - alpha) - 1) / (K - 1)) * y_gt + (alpha / (K - 1)) * \sum_{i} y_i
= (1 - (alpha * K) / (K - 1)) * y_gt + ( (alpha * K) / (K - 1) ) * \sum_{i} y_i / K
From: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/losses/smoothed_cross_entropy.py
"""
assert 1.0 > label_smoothing > 0.0
smoothing = label_smoothing * vocab_size / (vocab_size - 1)
# Exp logits at this point are normalized probabilities. So we can just take the log to get log-probs.
log_probs = torch.log(exp_logits)
mean_log_probs = log_probs.mean(dim=-1)
loss = (1.0 - smoothing) * loss - smoothing * mean_log_probs
ctx.label_smoothing, ctx.vocab_size = label_smoothing, vocab_size
ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
# Store softmax, target-mask and masked-target for backward pass.
ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
return loss
......@@ -73,6 +99,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
# Retreive tensors from the forward path.
softmax, target_mask, masked_target_1d = ctx.saved_tensors
label_smoothing, vocab_size = ctx.label_smoothing, ctx.vocab_size
# All the inputs have softmax as thier gradient.
grad_input = softmax
......@@ -83,15 +110,34 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
# Add the gradient from matching classes.
arange_1d = torch.arange(start=0, end=grad_2d.size()[0],
device=grad_2d.device)
grad_2d[arange_1d, masked_target_1d] -= (
1.0 - target_mask.view(-1).float())
softmax_update = 1.0 - target_mask.view(-1).float()
if label_smoothing > 0:
smoothing = label_smoothing * vocab_size / (vocab_size - 1)
grad_2d[arange_1d, masked_target_1d] -= (1.0 - smoothing) * softmax_update
average_grad = 1 / vocab_size
grad_2d[arange_1d, :] -= smoothing * average_grad
else:
grad_2d[arange_1d, masked_target_1d] -= softmax_update
# Finally elementwise multiplication with the output gradients.
grad_input.mul_(grad_output.unsqueeze(dim=-1))
return grad_input, None
return grad_input, None, None
def vocab_parallel_cross_entropy(vocab_parallel_logits, target, label_smoothing=0.0):
"""
Performs cross entropy loss when logits are split across tensor parallel ranks
Arguments:
vocab_parallel_logits: logits split across tensor parallel ranks
dimension is [sequence_length, batch_size, hidden_size]
target: correct vocab ids of dimseion [sequence_length, micro_batch_size]
def vocab_parallel_cross_entropy(vocab_parallel_logits, target):
"""Helper function for the cross entropy."""
return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target)
lobal_smoothing: smoothing factor, must be in range [0.0, 1.0)
default is no smoothing (=0.0)
"""
return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target, label_smoothing)
......@@ -2,9 +2,11 @@
import torch
from .initialize import get_tensor_model_parallel_group
from .initialize import get_tensor_model_parallel_rank
from .initialize import get_tensor_model_parallel_src_rank
from megatron.core.parallel_state import (
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_src_rank,
)
_MAX_DATA_DIM = 5
......
......@@ -2,7 +2,11 @@
import torch
from .initialize import get_tensor_model_parallel_group, get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank
from megatron.core.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_tensor_model_parallel_group,
)
from .utils import split_tensor_along_last_dim
......
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch
......@@ -11,13 +10,19 @@ from torch import _C
from torch.cuda import _lazy_call, device as device_ctx_manager
from torch.utils.checkpoint import detach_variable
from megatron.memory import allocate_mem_buff
from megatron.core.parallel_state import (
get_data_parallel_rank,
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from .initialize import get_data_parallel_rank
from .initialize import get_tensor_model_parallel_group
from .initialize import get_tensor_model_parallel_rank
from .initialize import get_tensor_model_parallel_world_size
from .utils import (
split_tensor_into_1d_equal_chunks,
gather_split_1d_tensor,
)
from megatron.core.utils import safely_set_viewless_tensor_data
# Default name for the model parallel rng tracker.
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
......@@ -56,117 +61,6 @@ def _set_cuda_rng_state(new_state, device=-1):
_lazy_call(cb)
def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
"""Break a tensor into equal 1D chunks."""
partition_size = torch.numel(tensor) // \
get_tensor_model_parallel_world_size()
start_index = partition_size * get_tensor_model_parallel_rank()
end_index = start_index + partition_size
if new_buffer:
data = torch.empty(partition_size, dtype=tensor.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
data.copy_(tensor.view(-1)[start_index:end_index])
else:
data = tensor.view(-1)[start_index:end_index]
return data
def gather_split_1d_tensor(tensor):
"""Opposite of above function, gather values from model parallel ranks."""
numel_gathered = torch.numel(tensor) * \
get_tensor_model_parallel_world_size()
gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
# TODO: This API is experimental in pytorch (as of Feb 2022) and
# this might break in future pytorch releases. We chose this API
# as opposed to torch.distributed.all_gather for efficiency reasons.
# This API calls directly NCCL all-gather versus the former does
# internal copies and can potentially cause slow down.
torch.distributed._all_gather_base(gathered, tensor,
group=get_tensor_model_parallel_group())
return gathered
def _kernel_make_viewless_tensor(inp, requires_grad):
'''Make a viewless tensor.
View tensors have the undesirable side-affect of retaining a reference
to the originally-viewed tensor, even after manually setting the '.data'
field. This method creates a new tensor that links to the old tensor's
data, without linking the viewed tensor, referenced via the '._base'
field.
'''
out = torch.empty(
(1,),
dtype = inp.dtype,
device = inp.device,
requires_grad = requires_grad,
)
out.data = inp.data
return out
class MakeViewlessTensor(torch.autograd.Function):
'''
Autograd function to make a viewless tensor.
This function should be used in cases where the computation graph needs
to be propagated, but we only want a viewless tensor (e.g.,
ParallelTransformer's hidden_states). Call this function by passing
'keep_graph = True' to 'make_viewless_tensor()'.
'''
@staticmethod
def forward(ctx, inp, requires_grad):
return _kernel_make_viewless_tensor(inp, requires_grad)
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
def make_viewless_tensor(inp, requires_grad, keep_graph):
'''
Entry-point for creating viewless tensors.
This method should be used, rather than calling 'MakeViewlessTensor'
or '_kernel_make_viewless_tensor' directly. This method acts as a
switch for determining if an autograd function or a regular method
should be used to create the tensor.
'''
# return tensor as-is, if not a 'view'
if inp._base is None:
return inp
# create viewless tensor
if keep_graph:
return MakeViewlessTensor.apply(inp, requires_grad)
else:
return _kernel_make_viewless_tensor(inp, requires_grad)
def assert_viewless_tensor(tensor, extra_msg = None):
'''Assert that a tensor is not a view (i.e., its '._base' field is
not set).'''
if isinstance(tensor, list):
[ assert_viewless_tensor(t) for t in tensor ]
return tensor
if not isinstance(tensor, torch.Tensor):
return tensor
assert tensor._base is None, (
"Ensure tensor._base is None before setting tensor.data or storing "
"tensor to memory buffer. Otherwise, a memory leak will occur (and "
"likely accumulate over iterations). %s"
) % extra_msg
return tensor
def safely_set_viewless_tensor_data(tensor, new_data_tensor):
'''Safely set tensor's '.data' field.
Check first that the tensor is viewless (i.e., '._base' not set). If not,
raise an exception.
'''
assert_viewless_tensor(tensor, extra_msg = "FYI, tensor._base has shape %s, and new_data_tensor has shape %s." % ("--" if tensor._base is None else tensor._base.shape, new_data_tensor.shape))
tensor.data = new_data_tensor
class CudaRNGStatesTracker:
"""Tracker for the cuda RNG states.
......@@ -271,13 +165,6 @@ def model_parallel_cuda_manual_seed(seed):
# Data parallel gets the original seed.
data_parallel_seed = seed
if torch.distributed.get_rank() == 0:
print('> initializing model parallel cuda seeds on global rank {}, '
'model parallel rank {}, and data parallel rank {} with '
'model parallel seed: {} and data parallel seed: {}'.format(
torch.distributed.get_rank(), get_tensor_model_parallel_rank(),
get_data_parallel_rank(), tensor_model_parallel_seed,
data_parallel_seed), flush=True)
_CUDA_RNG_STATE_TRACKER.reset()
# Set the default state.
torch.cuda.manual_seed(data_parallel_seed)
......
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import torch
from typing import List, Sequence
from megatron.core.utils import divide
from megatron.core import parallel_state
def split_tensor_along_last_dim(
tensor: torch.Tensor,
num_partitions: int,
contiguous_split_chunks: bool = False,
) -> List[torch.Tensor]:
""" Split a tensor along its last dimension.
Arguments:
tensor: input tensor.
num_partitions: number of partitions to split the tensor
contiguous_split_chunks: If True, make each chunk contiguous
in memory.
Returns:
A list of Tensors
"""
# Get the size and dimension.
last_dim = tensor.dim() - 1
last_dim_size = divide(tensor.size()[last_dim], num_partitions)
# Split.
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
# Note: torch.split does not create contiguous tensors by default.
if contiguous_split_chunks:
return tuple(chunk.contiguous() for chunk in tensor_list)
return tensor_list
def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
""" Break a tensor into equal 1D chunks across tensor parallel ranks.
Returns a Tensor or View with this rank's portion of the data.
Arguments:
tensor: The tensor to split
Keyword Arguments:
new_buffer (bool): If True, returns a new Tensor.
If False, returns a view into the existing Tensor.
Default is False
"""
partition_size = torch.numel(tensor) // \
parallel_state.get_tensor_model_parallel_world_size()
start_index = partition_size * parallel_state.get_tensor_model_parallel_rank()
end_index = start_index + partition_size
if new_buffer:
data = torch.empty(partition_size, dtype=tensor.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
data.copy_(tensor.view(-1)[start_index:end_index])
else:
data = tensor.view(-1)[start_index:end_index]
return data
def gather_split_1d_tensor(tensor):
""" Opposite of split_tensor_into_1d_equal_chunks. Gather values from tensor
model parallel ranks.
Returns a new Tensor with the gathered data.
Arguments:
tensor: A Tensor or view of this rank's portion of the data.
"""
numel_gathered = torch.numel(tensor) * \
parallel_state.get_tensor_model_parallel_world_size()
gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
# TODO: This API is experimental in pytorch (as of Feb 2022) and
# this might break in future pytorch releases. We chose this API
# as opposed to torch.distributed.all_gather for efficiency reasons.
# This API calls directly NCCL all-gather versus the former does
# internal copies and can potentially cause slow down.
torch.distributed._all_gather_base(gathered, tensor,
group=parallel_state.get_tensor_model_parallel_group())
return gathered
class VocabUtility:
""" Split the vocabulary into `world_size` chunks and return the first
and last index of the vocabulary belonging to the `rank`
partition: Note that indices in [fist, last)
"""
@staticmethod
def vocab_range_from_per_partition_vocab_size(
per_partition_vocab_size: int, rank, world_size: int
) -> Sequence[int]:
index_f = rank * per_partition_vocab_size
index_l = index_f + per_partition_vocab_size
return index_f, index_l
@staticmethod
def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int, world_size: int) -> Sequence[int]:
per_partition_vocab_size = divide(global_vocab_size, world_size)
return VocabUtility.vocab_range_from_per_partition_vocab_size(
per_partition_vocab_size, rank, world_size
)
"""Utility functions used throughout Megatron core"""
from functools import reduce
import operator
import torch
from megatron.core import parallel_state
def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
assert numerator % denominator == 0, "{} is not divisible by {}".format(
numerator, denominator
)
def divide(numerator, denominator):
"""Ensure that numerator is divisible by the denominator and return
the division value."""
ensure_divisibility(numerator, denominator)
return numerator // denominator
class GlobalMemoryBuffer:
"""Global buffer to avoid dynamic memory allocations.
Caller should ensure that buffers of the same name
are not used concurrently."""
def __init__(self):
self.buffer = {}
def get_tensor(self, tensor_shape, dtype, name):
required_len = reduce(operator.mul, tensor_shape, 1)
if self.buffer.get((name, dtype), None) is None or \
self.buffer[(name, dtype)].numel() < required_len:
self.buffer[(name, dtype)] = \
torch.empty(required_len,
dtype=dtype,
device=torch.cuda.current_device(),
requires_grad=False)
return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape)
def _kernel_make_viewless_tensor(inp, requires_grad):
'''Make a viewless tensor.
View tensors have the undesirable side-affect of retaining a reference
to the originally-viewed tensor, even after manually setting the '.data'
field. This method creates a new tensor that links to the old tensor's
data, without linking the viewed tensor, referenced via the '._base'
field.
'''
out = torch.empty(
(1,),
dtype = inp.dtype,
device = inp.device,
requires_grad = requires_grad,
)
out.data = inp.data
return out
class MakeViewlessTensor(torch.autograd.Function):
'''
Autograd function to make a viewless tensor.
This function should be used in cases where the computation graph needs
to be propagated, but we only want a viewless tensor (e.g.,
ParallelTransformer's hidden_states). Call this function by passing
'keep_graph = True' to 'make_viewless_tensor()'.
'''
@staticmethod
def forward(ctx, inp, requires_grad):
return _kernel_make_viewless_tensor(inp, requires_grad)
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
def make_viewless_tensor(inp, requires_grad, keep_graph):
'''
Entry-point for creating viewless tensors.
This method should be used, rather than calling 'MakeViewlessTensor'
or '_kernel_make_viewless_tensor' directly. This method acts as a
switch for determining if an autograd function or a regular method
should be used to create the tensor.
'''
# return tensor as-is, if not a 'view'
if inp._base is None:
return inp
# create viewless tensor
if keep_graph:
return MakeViewlessTensor.apply(inp, requires_grad)
else:
return _kernel_make_viewless_tensor(inp, requires_grad)
def assert_viewless_tensor(tensor, extra_msg = None):
'''Assert that a tensor is not a view (i.e., its '._base' field is
not set).'''
if isinstance(tensor, list):
[ assert_viewless_tensor(t) for t in tensor ]
return tensor
if not isinstance(tensor, torch.Tensor):
return tensor
assert tensor._base is None, (
"Ensure tensor._base is None before setting tensor.data or storing "
"tensor to memory buffer. Otherwise, a memory leak will occur (and "
"likely accumulate over iterations). %s"
) % extra_msg
return tensor
def safely_set_viewless_tensor_data(tensor, new_data_tensor):
'''Safely set tensor's '.data' field.
Check first that the tensor is viewless (i.e., '._base' not set). If not,
raise an exception.
'''
assert_viewless_tensor(tensor, extra_msg = "FYI, tensor._base has shape %s, and new_data_tensor has shape %s." % ("--" if tensor._base is None else tensor._base.shape, new_data_tensor.shape))
tensor.data = new_data_tensor
......@@ -4,7 +4,8 @@ import time
import numpy as np
import torch
from megatron import get_args, get_tokenizer, mpu, print_rank_0
from megatron import get_args, get_tokenizer, print_rank_0
from megatron.core import mpu, tensor_parallel
from megatron.data.dataset_utils import create_masked_lm_predictions, \
pad_and_convert_to_numpy
from megatron.data.data_samplers import MegatronPretrainingSampler
......@@ -57,7 +58,7 @@ def get_ict_batch(data_iterator):
data = None
else:
data = next(data_iterator)
data_b = mpu.broadcast_data(keys, data, datatype)
data_b = tensor_parallel.broadcast_data(keys, data, datatype)
# Unpack.
query_tokens = data_b['query_tokens'].long()
......
......@@ -8,8 +8,6 @@ import numpy as np
import torch
from megatron import print_rank_0
from megatron import mpu
class BlendableDataset(torch.utils.data.Dataset):
......
......@@ -8,7 +8,7 @@ import torch
import numpy as np
from torch.utils.data import Dataset
from megatron import get_args
from megatron import mpu
from megatron.core import mpu
def build_pretraining_data_loader(dataset, consumed_samples):
......
......@@ -28,9 +28,9 @@ import torch
from megatron import (
get_args,
mpu,
print_rank_0
)
from megatron.core import mpu
from megatron.data.blendable_dataset import BlendableDataset
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment