"docs/developer_guide/benchmark_and_profiling.md" did not exist on "4d2a88bdffe91168dfc73ef7e3bc9100ba96686b"
Commit 42c40715 authored by Jared Casper

Merge branch 'core' into 'main'

Move most of mpu functionality into a new "Megatron core"

See merge request ADLR/megatron-lm!462
parents fdc801ed d3a416cd
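
All of the hunks below follow the same pattern: call sites stop importing `mpu` from the top-level `megatron` package and instead import from the new `megatron.core` package (and, for layer/attribute helpers, from `megatron.core.tensor_parallel`). As a quick orientation, here is a minimal sketch of that import migration, using only names that appear in the diffs; it assumes a Megatron-LM checkout at or after this merge is importable.

```python
# Old import paths (the "-" side of the hunks below):
#   from megatron import mpu
#   from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate

# New import paths (the "+" side of the hunks below):
from megatron.core import mpu, tensor_parallel
from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate

# Process-group queries stay on mpu, e.g. mpu.get_data_parallel_rank();
# tensor-parallel helpers move to tensor_parallel, e.g. tensor_parallel.broadcast_data(...).
```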
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Model parallel utility interface."""
from .cross_entropy import vocab_parallel_cross_entropy
from .data import broadcast_data
from .initialize import is_unitialized
from .initialize import destroy_model_parallel
from .initialize import get_data_parallel_group
from .initialize import get_data_parallel_rank
from .initialize import get_data_parallel_world_size
from .initialize import get_embedding_group
from .initialize import get_position_embedding_group
from .initialize import get_model_parallel_group
from .initialize import get_tensor_model_parallel_group
from .initialize import get_pipeline_model_parallel_group
from .initialize import get_tensor_model_parallel_rank, set_tensor_model_parallel_rank
from .initialize import get_pipeline_model_parallel_rank, set_pipeline_model_parallel_rank
from .initialize import is_pipeline_first_stage, is_pipeline_last_stage
from .initialize import is_rank_in_embedding_group
from .initialize import is_rank_in_position_embedding_group
from .initialize import is_pipeline_stage_before_split, is_pipeline_stage_after_split
from .initialize import is_pipeline_stage_at_split
from .initialize import get_num_layers
from .initialize import get_tensor_model_parallel_src_rank
from .initialize import get_data_parallel_src_rank
from .initialize import get_pipeline_model_parallel_first_rank
from .initialize import get_pipeline_model_parallel_last_rank
from .initialize import get_pipeline_model_parallel_next_rank
from .initialize import get_pipeline_model_parallel_prev_rank
from .initialize import get_tensor_model_parallel_world_size, set_tensor_model_parallel_world_size
from .initialize import get_pipeline_model_parallel_world_size, set_pipeline_model_parallel_world_size
from .initialize import get_virtual_pipeline_model_parallel_rank, set_virtual_pipeline_model_parallel_rank
from .initialize import initialize_model_parallel
from .initialize import model_parallel_is_initialized
from .layers import LinearWithGradAccumulationAndAsyncCommunication
from .layers import ColumnParallelLinear
from .layers import RowParallelLinear
from .layers import VocabParallelEmbedding
from .layers import (set_tensor_model_parallel_attributes,
                     set_defaults_if_not_set_tensor_model_parallel_attributes,
                     copy_tensor_model_parallel_attributes)
from .mappings import copy_to_tensor_model_parallel_region
from .mappings import reduce_from_tensor_model_parallel_region
from .mappings import scatter_to_tensor_model_parallel_region
from .mappings import gather_from_tensor_model_parallel_region
from .mappings import scatter_to_sequence_parallel_region
from .mappings import gather_from_sequence_parallel_region
from .mappings import reduce_scatter_to_sequence_parallel_region
from .random import checkpoint
from .random import get_cuda_rng_tracker
from .random import model_parallel_cuda_manual_seed
from .random import gather_split_1d_tensor
from .random import split_tensor_into_1d_equal_chunks
from .random import make_viewless_tensor
from .random import assert_viewless_tensor
from .random import safely_set_viewless_tensor_data
from .utils import divide
from .utils import split_tensor_along_last_dim
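
As a usage illustration of the parallel-state API listed above (available from `megatron.core` after this merge), here is a minimal, hypothetical sketch that brings up 2-way tensor and 2-way pipeline parallelism and queries the resulting ranks. It assumes a `torchrun` launch on a multi-GPU node and passes the parallel sizes positionally; it is a sketch, not a prescribed workflow.

```python
# Minimal sketch (hypothetical): initialize model parallelism and query ranks.
import torch
from megatron.core import mpu  # location after this merge

torch.distributed.init_process_group(backend='nccl')  # env set by torchrun
torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count())

mpu.initialize_model_parallel(2, 2)  # tensor-parallel size, pipeline-parallel size
assert mpu.model_parallel_is_initialized()

print('tp rank', mpu.get_tensor_model_parallel_rank(),
      'pp rank', mpu.get_pipeline_model_parallel_rank(),
      'dp rank', mpu.get_data_parallel_rank())

mpu.destroy_model_parallel()
```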
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import torch
def ensure_divisibility(numerator, denominator):
    """Ensure that numerator is divisible by the denominator."""
    assert numerator % denominator == 0, '{} is not divisible by {}'.format(
        numerator, denominator)


def divide(numerator, denominator):
    """Ensure that numerator is divisible by the denominator and return
    the division value."""
    ensure_divisibility(numerator, denominator)
    return numerator // denominator


def split_tensor_along_last_dim(tensor, num_partitions,
                                contiguous_split_chunks=False):
    """Split a tensor along its last dimension.
    Arguments:
        tensor: input tensor.
        num_partitions: number of partitions to split the tensor
        contiguous_split_chunks: If True, make each chunk contiguous
                                 in memory.
    """
    # Get the size and dimension.
    last_dim = tensor.dim() - 1
    last_dim_size = divide(tensor.size()[last_dim], num_partitions)
    # Split.
    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
    # Note: torch.split does not create contiguous tensors by default.
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in tensor_list)
    return tensor_list
class VocabUtility:
    """Split the vocabulary into `world_size` chunks and return the
    first and last index of the vocabulary belonging to the `rank`
    partition. Note that indices are in [first, last)."""

    @staticmethod
    def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size,
                                                  rank, world_size):
        index_f = rank * per_partition_vocab_size
        index_l = index_f + per_partition_vocab_size
        return index_f, index_l

    @staticmethod
    def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
        per_partition_vocab_size = divide(global_vocab_size, world_size)
        return VocabUtility.vocab_range_from_per_partition_vocab_size(
            per_partition_vocab_size, rank, world_size)
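
A quick, hypothetical illustration of the helpers above on a single process (no distributed setup needed), assuming the definitions above (`divide`, `split_tensor_along_last_dim`, `VocabUtility`) are in scope: `divide` enforces exact divisibility, `split_tensor_along_last_dim` partitions the hidden dimension, and `VocabUtility` maps a rank to its contiguous vocabulary slice.

```python
import torch

# divide() returns the quotient but asserts exact divisibility.
assert divide(12, 4) == 3
# divide(10, 4) would raise: AssertionError: 10 is not divisible by 4

# Split the last (hidden) dimension into 4 equal partitions.
x = torch.randn(2, 5, 12)
chunks = split_tensor_along_last_dim(x, 4, contiguous_split_chunks=True)
assert len(chunks) == 4 and chunks[0].shape == (2, 5, 3)

# Rank 1 of a 4-way split of a 16-token vocabulary owns indices [4, 8).
first, last = VocabUtility.vocab_range_from_global_vocab_size(16, rank=1, world_size=4)
assert (first, last) == (4, 8)
```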
@@ -9,7 +9,7 @@ from apex.multi_tensor_apply import multi_tensor_applier
 import amp_C
 from megatron.model.module import param_is_not_shared
-from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
+from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate
 def clip_grad_norm_fp32(parameters, grads_for_norm,
...
@@ -8,10 +8,9 @@ import torch
 from megatron import get_args
 from megatron import get_timers
-from megatron import mpu
 from megatron import print_rank_0
+from megatron.core import mpu, tensor_parallel
 from megatron.model.module import param_is_not_shared
-from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
 from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper
@@ -290,9 +289,9 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         shard_model_param = model_param.detach().view(-1) \
             [param_range.start:param_range.end]
         shard_main_param = shard_model_param.clone().float()
-        mpu.copy_tensor_model_parallel_attributes(
+        tensor_parallel.copy_tensor_model_parallel_attributes(
             shard_model_param, model_param)
-        mpu.copy_tensor_model_parallel_attributes(
+        tensor_parallel.copy_tensor_model_parallel_attributes(
             shard_main_param, model_param)
         if hasattr(model_param, 'shared'):
             shard_model_param.shared = model_param.shared
@@ -309,7 +308,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
             [param_range.start:param_range.end]
         model_fp32_params_this_group.append(model_param)
         shard_fp32_params_this_group.append(shard_model_param)
-        mpu.copy_tensor_model_parallel_attributes(
+        tensor_parallel.copy_tensor_model_parallel_attributes(
             shard_model_param, model_param)
         if hasattr(model_param, 'shared'):
             shard_model_param.shared = model_param.shared
...
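
The hunk above copies tensor-parallel attributes from each full model parameter onto its flattened shard so the distributed optimizer still knows how that parameter is partitioned. Below is a small, hypothetical sketch of what that copy preserves; the attribute names (`tensor_model_parallel`, `partition_dim`, `partition_stride`) follow the `set_tensor_model_parallel_attributes` helper re-exported in the `__init__` listing above, and the shapes are made up for illustration.

```python
import torch
from megatron.core.tensor_parallel import (
    set_tensor_model_parallel_attributes,
    copy_tensor_model_parallel_attributes,
)

# A column-parallel weight, marked as partitioned along dim 0.
model_param = torch.nn.Parameter(torch.empty(1024, 4096))
set_tensor_model_parallel_attributes(model_param, is_parallel=True, dim=0, stride=1)

# The distributed optimizer works on a flattened slice of that weight;
# copying the attributes keeps the shard's parallel metadata consistent.
shard = model_param.detach().view(-1)[:2048]
copy_tensor_model_parallel_attributes(shard, model_param)
assert shard.tensor_model_parallel and shard.partition_dim == 0
```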
@@ -11,12 +11,11 @@ from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from megatron import get_timers
-from megatron import mpu
 from megatron import print_rank_0
+from megatron.core import mpu, tensor_parallel
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import Float16Module
 from megatron.model.module import param_is_not_shared
-from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
 from megatron.utils import unwrap_model
 from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32
@@ -102,7 +101,7 @@ class MegatronOptimizer(ABC):
             grad = param.grad
             grad_not_none = grad is not None
             is_not_shared = param_is_not_shared(param)
-            is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
+            is_not_tp_duplicate = tensor_parallel.param_is_not_tensor_parallel_duplicate(param)
             if grad_not_none and is_not_shared and is_not_tp_duplicate:
                 grads_for_norm.append(grad)
@@ -529,7 +528,7 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
                 # Create a copy
                 main_param = param.detach().clone().float()
                 # Copy tensor model parallel attributes.
-                mpu.copy_tensor_model_parallel_attributes(main_param,
+                tensor_parallel.copy_tensor_model_parallel_attributes(main_param,
                                                           param)
                 if hasattr(param, 'shared'):
                     main_param.shared = param.shared
...
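
The second hunk above builds `grads_for_norm` by filtering out shared parameters and tensor-parallel duplicates so each gradient contributes to the global norm exactly once. The sketch below illustrates that pattern end to end; the all-reduce over the model-parallel group mirrors what the gradient-clipping path does in spirit, but this helper is illustrative only (it is not the library's `clip_grad_norm_fp32`), and it assumes distributed and model-parallel state are already initialized on CUDA devices.

```python
import torch
from megatron.core import mpu, tensor_parallel
from megatron.model.module import param_is_not_shared


def grads_for_global_norm(parameters):
    """Collect the gradients that should count toward the model-wide grad norm."""
    grads = []
    for param in parameters:
        if param.grad is None:
            continue
        # Skip shared params (e.g. tied embeddings) and tensors replicated
        # across tensor-parallel ranks so nothing is counted twice.
        if param_is_not_shared(param) and \
                tensor_parallel.param_is_not_tensor_parallel_duplicate(param):
            grads.append(param.grad.detach())
    return grads


def global_grad_l2_norm(parameters):
    """Illustrative L2 norm over the model-parallel group (not the library routine)."""
    norm_sq = torch.zeros(1, dtype=torch.float32, device=torch.cuda.current_device())
    for grad in grads_for_global_norm(parameters):
        norm_sq += grad.float().norm() ** 2
    # Sum the per-rank squared norms across the model-parallel group.
    torch.distributed.all_reduce(norm_sq,
                                 op=torch.distributed.ReduceOp.SUM,
                                 group=mpu.get_model_parallel_group())
    return norm_sq.sqrt().item()
```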
@@ -4,8 +4,8 @@ from functools import reduce
 import operator
 import torch
-from megatron import get_args
-from megatron import mpu
+from megatron import get_args, core
+from megatron.core import mpu
 def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
@@ -81,10 +81,10 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
        args.scatter_gather_tensors_in_pipeline and \
        not args.sequence_parallel:
         if tensor_send_next is not None:
-            tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)
+            tensor_send_next = core.tensor_parallel.split_tensor_into_1d_equal_chunks(tensor_send_next)
         if tensor_send_prev is not None:
-            tensor_send_prev = mpu.split_tensor_into_1d_equal_chunks(tensor_send_prev)
+            tensor_send_prev = core.tensor_parallel.split_tensor_into_1d_equal_chunks(tensor_send_prev)
     # Send tensors in both the forward and backward directions as appropriate.
     if args.use_ring_exchange_p2p:
@@ -127,16 +127,16 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
        args.scatter_gather_tensors_in_pipeline and \
        not args.sequence_parallel:
         if recv_prev:
-            tensor_recv_prev = mpu.gather_split_1d_tensor(
+            tensor_recv_prev = core.tensor_parallel.gather_split_1d_tensor(
                 tensor_recv_prev).view(tensor_shape).requires_grad_()
-            tensor_recv_prev = mpu.make_viewless_tensor(tensor_recv_prev,
+            tensor_recv_prev = core.utils.make_viewless_tensor(tensor_recv_prev,
                                                         requires_grad = True,
                                                         keep_graph = False)
         if recv_next:
-            tensor_recv_next = mpu.gather_split_1d_tensor(
+            tensor_recv_next = core.tensor_parallel.gather_split_1d_tensor(
                 tensor_recv_next).view(tensor_shape).requires_grad_()
-            tensor_recv_next = mpu.make_viewless_tensor(tensor_recv_next,
+            tensor_recv_next = core.utils.make_viewless_tensor(tensor_recv_next,
                                                         requires_grad = True,
                                                         keep_graph = False)
...
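
The scatter-gather optimization above sends only an equal 1D chunk of each pipeline activation per tensor-parallel rank: the sender splits the flattened tensor into equal chunks, and the receiver gathers them back and restores the shape with `.view(tensor_shape)`. The following single-process, hypothetical emulation uses plain torch ops instead of the distributed helpers to show the round trip and why the element count must divide evenly by the tensor-parallel world size.

```python
import torch

tensor_shape = (4, 2, 8)   # e.g. (sequence, micro-batch, hidden); made-up sizes
tp_world_size = 4          # assumed tensor-model-parallel size

activation = torch.randn(tensor_shape)

# "Sender" side: flatten and form equal-sized 1D chunks, one per rank.
flat = activation.view(-1)
assert flat.numel() % tp_world_size == 0, 'chunks must divide evenly'
chunk_size = flat.numel() // tp_world_size
chunks = [flat[i * chunk_size:(i + 1) * chunk_size] for i in range(tp_world_size)]

# "Receiver" side: gather all chunks and restore the original shape,
# mirroring gather_split_1d_tensor(...).view(tensor_shape) in the hunk above.
gathered = torch.cat(chunks).view(tensor_shape)
assert torch.equal(gathered, activation)
```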
@@ -8,8 +8,8 @@ from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
 from megatron import get_args
 from megatron import get_num_microbatches
 from megatron import get_timers
-from megatron import mpu
 from megatron import p2p_communication
+from megatron.core import mpu
 from megatron.utils import unwrap_model
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import Float16Module
...
@@ -5,7 +5,7 @@
 import torch
-from megatron import mpu
+from megatron.core import mpu
 from .communication import broadcast_float_list
 from .generation import (
     generate_tokens_probs_and_return_on_first_stage,
...
@@ -5,7 +5,7 @@
 import torch
-from megatron import mpu
+from megatron.core import mpu
...
@@ -6,9 +6,8 @@ from collections.abc import Iterable
 import torch
-from megatron import (
-    get_args,
-    mpu)
+from megatron import get_args
+from megatron.core import mpu
 from .communication import (
     send_to_next_pipeline_rank,
     recv_from_prev_pipeline_rank_)
...
@@ -5,7 +5,8 @@
 import torch
 import torch.nn.functional as F
-from megatron import get_args, get_tokenizer, mpu
+from megatron import get_args, get_tokenizer
+from megatron.core import mpu
 from megatron.utils import get_ltor_masks_and_position_ids
 from .communication import (
     copy_from_last_to_first_pipeline_stage,
...
@@ -19,7 +19,7 @@ from megatron import get_current_global_batch_size
 from megatron import get_num_microbatches
 from megatron import is_last_rank
 from megatron import update_num_microbatches
-from megatron import mpu
+from megatron.core import mpu, tensor_parallel
 from megatron import print_rank_0
 from megatron import print_rank_last
 from megatron.checkpointing import load_checkpoint
@@ -257,7 +257,7 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
     # are set for all params so the optimizer can use them.
     for model_module in model:
         for param in model_module.parameters():
-            mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)
+            tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param)
     # Print number of parameters.
     if mpu.get_data_parallel_rank() == 0:
...
@@ -10,11 +10,13 @@ from torch.nn.parallel import DistributedDataParallel as torchDDP
 from apex.multi_tensor_apply import multi_tensor_applier
 import amp_C
-from megatron import get_args
-from megatron import get_adlr_autoresume
-from megatron import mpu
+from megatron import (
+    get_args,
+    get_adlr_autoresume,
+)
+from megatron.core import mpu
+from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate
 from megatron.model.module import param_is_not_shared
-from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
 def unwrap_model(model, module_instances=(torchDDP)):
...
@@ -10,7 +10,7 @@ import torch.nn.functional as F
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_timers
-from megatron import mpu
+from megatron.core import tensor_parallel
 from megatron.data.dataset_utils import build_train_valid_test_datasets
 from megatron.model import BertModel, ModelType
 from megatron.training import pretrain
@@ -46,7 +46,7 @@ def get_batch(data_iterator):
         data = next(data_iterator)
     else:
         data = None
-    data_b = mpu.broadcast_data(keys, data, datatype)
+    data_b = tensor_parallel.broadcast_data(keys, data, datatype)
     # Unpack.
     tokens = data_b['text'].long()
...
@@ -8,7 +8,7 @@ from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_timers
 from megatron import get_tokenizer
-from megatron import mpu
+from megatron.core import tensor_parallel
 from megatron.data.gpt_dataset import build_train_valid_test_datasets
 from megatron.model import GPTModel, ModelType
 from megatron.training import pretrain
@@ -42,7 +42,7 @@ def get_batch(data_iterator):
         data = next(data_iterator)
     else:
         data = None
-    data_b = mpu.broadcast_data(keys, data, datatype)
+    data_b = tensor_parallel.broadcast_data(keys, data, datatype)
     # Unpack.
     tokens_ = data_b['text'].long()
...
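
The `get_batch` hunks above (BERT, GPT, and T5 below) all rely on the same pattern: only tensor-parallel rank 0 pulls a batch from the data iterator, then `broadcast_data` ships the named int64 tensors to the other ranks of the tensor-parallel group. A hedged sketch of that pattern follows; the rank-0 check is an assumption about the elided lines just above the hunks, while the `broadcast_data(keys, data, datatype)` call itself matches the diffs.

```python
import torch
from megatron.core import mpu, tensor_parallel


def get_batch_sketch(data_iterator):
    """Illustrative only: fetch on TP rank 0, broadcast to the rest of the group."""
    keys = ['text']
    datatype = torch.int64

    # Only the first tensor-parallel rank touches the data iterator.
    if mpu.get_tensor_model_parallel_rank() == 0:
        data = next(data_iterator)
    else:
        data = None

    # broadcast_data returns a dict of tensors for `keys` on every TP rank.
    data_b = tensor_parallel.broadcast_data(keys, data, datatype)
    tokens = data_b['text'].long()
    return tokens
```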
@@ -12,7 +12,7 @@ import torch.nn.functional as F
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_timers
-from megatron import mpu
+from megatron.core import mpu
 from megatron.data.biencoder_dataset_utils import get_ict_batch
 from megatron.data.dataset_utils import build_train_valid_test_datasets
 from megatron.model import ModelType
...
@@ -9,9 +9,9 @@ import torch
 from megatron import (
     get_args,
     get_timers,
-    mpu,
     print_rank_0
 )
+from megatron.core import tensor_parallel
 from megatron.data.dataset_utils import build_train_valid_test_datasets
 from megatron.model import T5Model, ModelType
 from megatron.training import pretrain
@@ -80,7 +80,7 @@ def get_batch(data_iterator):
         data = next(data_iterator)
     else:
         data = None
-    data_b = mpu.broadcast_data(keys, data, datatype)
+    data_b = tensor_parallel.broadcast_data(keys, data, datatype)
     # Unpack.
     tokens_enc = data_b['text_enc'].long()
...
@@ -5,7 +5,7 @@
 import torch
 import torch.nn.functional as F
 from functools import partial
-from megatron import get_args, get_timers, mpu, print_rank_0
+from megatron import get_args, get_timers, print_rank_0
 from megatron.data.vit_dataset import build_train_valid_datasets
 from megatron.model import ModelType
 from megatron.model.vision.classification import VitClassificationModel
...
@@ -6,7 +6,7 @@ import torch.nn as nn
 import numpy as np
 import torch.distributed as dist
 from functools import partial
-from megatron import get_args, get_timers, mpu, print_rank_0
+from megatron import get_args, get_timers, print_rank_0
 from megatron.data.vit_dataset import build_train_valid_datasets
 from megatron.model.vision.dino import DINOPretrainModel
 from megatron.model.vision.knn_monitor import knn_predict, get_feature_bank
...
@@ -5,7 +5,7 @@
 import torch
 import torch.nn.functional as F
 from functools import partial
-from megatron import get_args, get_timers, mpu, print_rank_0, print_rank_last
+from megatron import get_args, get_timers, print_rank_0, print_rank_last
 from megatron.data.vit_dataset import build_train_valid_datasets
 from megatron.model.vision.inpainting import VitInpaintingModel
 from megatron.model.vision.inpainting import MitInpaintingModel
...