Commit 42c40715 authored by Jared Casper

Merge branch 'core' into 'main'

Move most of mpu functionality into a new "Megatron core"

See merge request ADLR/megatron-lm!462
parents fdc801ed d3a416cd
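The refactor splits the old megatron.mpu package three ways: parallel-state queries stay reachable as megatron.core.mpu, tensor-parallel layers and ops move to megatron.core.tensor_parallel, and small helpers such as divide and make_viewless_tensor move to megatron.core.utils. A minimal before/after import sketch, assuming this branch of the repository is installed (only the new imports below are meant to execute):

    # Before this commit:
    #     from megatron import mpu
    #     mpu.get_tensor_model_parallel_world_size()     # parallel state
    #     mpu.ColumnParallelLinear(...)                   # tensor-parallel layers
    #     mpu.broadcast_data(...), mpu.checkpoint(...)    # tensor-parallel ops
    #     mpu.divide(...), mpu.make_viewless_tensor(...)  # small utilities
    #
    # After this commit:
    from megatron.core import mpu, tensor_parallel        # parallel state + tensor-parallel layers/ops
    from megatron.core.utils import divide, make_viewless_tensor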
......@@ -8,7 +8,8 @@ import time
import numpy as np
import torch
from megatron import mpu, print_rank_0
from megatron import print_rank_0
from megatron.core import mpu
from megatron.data.blendable_dataset import BlendableDataset
from megatron.data.dataset_utils import get_datasets_weights_and_num_samples
from megatron.data.dataset_utils import get_train_valid_test_split_
......
......@@ -9,7 +9,8 @@ import random
import torch
from torch.utils.data import Dataset
from megatron import print_rank_0, get_args, get_tokenizer, mpu
from megatron import print_rank_0, get_args, get_tokenizer
from megatron.core import tensor_parallel
from megatron.data.biencoder_dataset_utils import make_attention_mask
def get_open_retrieval_wiki_dataset():
......@@ -32,7 +33,7 @@ def get_open_retrieval_batch(data_iterator):
# Broadcast data.
data = None if data_iterator is None else next(data_iterator)
data_b = mpu.broadcast_data(keys, data, datatype)
data_b = tensor_parallel.broadcast_data(keys, data, datatype)
# Unpack.
row_id = data_b['row_id'].long()
......
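The broadcast helper keeps its signature under the new name: tensor_parallel.broadcast_data(keys, data, datatype) still takes the key list, a data dict (only needed on tensor-parallel rank 0), and the tensor dtype. A hedged sketch of the calling pattern in this hunk; the helper name broadcast_batch is illustrative, and the call assumes model parallelism has already been initialized:

    import torch
    from megatron.core import tensor_parallel

    def broadcast_batch(data_iterator, keys, datatype=torch.int64):
        # Only tensor-parallel rank 0 has a real iterator; the other ranks pass
        # None and receive the named tensors via the broadcast below.
        data = None if data_iterator is None else next(data_iterator)
        return tensor_parallel.broadcast_data(keys, data, datatype)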
......@@ -4,9 +4,10 @@ import time
import numpy as np
import torch
from megatron import mpu, print_rank_0
from megatron import print_rank_0
from megatron.core import mpu, tensor_parallel
from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy
from megatron import get_args, get_tokenizer, print_rank_0, mpu
from megatron import get_args, get_tokenizer, print_rank_0
def get_one_epoch_dataloader(dataset, micro_batch_size=None):
......@@ -47,7 +48,7 @@ def get_ict_batch(data_iterator):
data = None
else:
data = next(data_iterator)
data_b = mpu.broadcast_data(keys, data, datatype)
data_b = tensor_parallel.broadcast_data(keys, data, datatype)
# Unpack.
query_tokens = data_b['query_tokens'].long()
......
......@@ -7,7 +7,7 @@ import numpy as np
import torch
from megatron import get_args
from megatron import mpu
from megatron.core import mpu
def detach(tensor):
......@@ -50,10 +50,10 @@ class OpenRetreivalDataStore(object):
def load_from_file(self):
"""Populate members from instance saved to file"""
if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
print("\n> Unpickling BlockData", flush=True)
state_dict = pickle.load(open(self.embedding_path, 'rb'))
if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
print(">> Finished unpickling BlockData\n", flush=True)
self.embed_data = state_dict['embed_data']
......@@ -137,7 +137,7 @@ class FaissMIPSIndex(object):
except ImportError:
raise Exception("Error: Please install faiss to use FaissMIPSIndex")
if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
print("\n> Building index", flush=True)
cpu_index = faiss.IndexFlatIP(self.embed_size)
......@@ -149,12 +149,12 @@ class FaissMIPSIndex(object):
config.useFloat16 = True
gpu_index = faiss.index_cpu_to_all_gpus(cpu_index, co=config)
self.mips_index = faiss.IndexIDMap(gpu_index)
if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
print(">> Initialized index on GPU", flush=True)
else:
# CPU index supports IDs so wrap with IDMap
self.mips_index = faiss.IndexIDMap(cpu_index)
if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
print(">> Initialized index on CPU", flush=True)
# if we were constructed with a BlockData, then automatically load it
......@@ -199,7 +199,7 @@ class FaissMIPSIndex(object):
self.mips_index.add_with_ids(embeds_arr, indices_arr)
if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
print(">>> Finished adding block data to index", flush=True)
def search_mips_index(self, query_embeds, top_k, reconstruct=True):
......
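The repeated guard change above is a rename plus a polarity flip: the old mpu.is_unitialized() (spelled that way in the old package) returned True when model parallelism had not been set up, so the new code negates mpu.model_parallel_is_initialized() to preserve the guard's meaning, i.e. print only when running un-parallelized or on data-parallel rank 0. A small hedged helper capturing that condition (the name should_log is illustrative):

    from megatron.core import mpu

    def should_log():
        # True when model parallelism is not initialized, or on data-parallel rank 0.
        return (not mpu.model_parallel_is_initialized()
                or mpu.get_data_parallel_rank() == 0)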
......@@ -4,8 +4,6 @@
import os
import sys
from functools import reduce
import operator
import torch
from megatron import dist_signal_handler
......@@ -20,7 +18,6 @@ _GLOBAL_TENSORBOARD_WRITER = None
_GLOBAL_ADLR_AUTORESUME = None
_GLOBAL_TIMERS = None
_GLOBAL_SIGNAL_HANDLER = None
_GLOBAL_MEMORY_BUFFER = None
def get_args():
"""Return arguments."""
......@@ -70,11 +67,6 @@ def get_signal_handler():
return _GLOBAL_SIGNAL_HANDLER
def get_global_memory_buffer():
_ensure_var_is_initialized(_GLOBAL_MEMORY_BUFFER, 'global memory buffer')
return _GLOBAL_MEMORY_BUFFER
def _set_signal_handler():
global _GLOBAL_SIGNAL_HANDLER
_ensure_var_is_not_initialized(_GLOBAL_SIGNAL_HANDLER, 'signal handler')
......@@ -96,7 +88,6 @@ def set_global_variables(args):
_set_tensorboard_writer(args)
_set_adlr_autoresume(args)
_set_timers(args)
_set_global_memory_buffer()
if args.exit_signal_handler:
_set_signal_handler()
......@@ -176,13 +167,6 @@ def _set_timers(args):
_GLOBAL_TIMERS = Timers(args.timing_log_level, args.timing_log_option)
def _set_global_memory_buffer():
"""Initialize global buffer"""
global _GLOBAL_MEMORY_BUFFER
_ensure_var_is_not_initialized(_GLOBAL_MEMORY_BUFFER, 'global memory buffer')
_GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer()
def _ensure_var_is_initialized(var, name):
"""Make sure the input variable is not None."""
assert var is not None, '{} is not initialized.'.format(name)
......@@ -194,22 +178,3 @@ def _ensure_var_is_not_initialized(var, name):
class GlobalMemoryBuffer:
"""Global buffer to avoid dynamic memory allocations.
Caller should ensure that buffers of the same name
are not used concurrently."""
def __init__(self):
self.buffer = {}
def get_tensor(self, tensor_shape, dtype, name):
required_len = reduce(operator.mul, tensor_shape, 1)
if self.buffer.get((name, dtype), None) is None or \
self.buffer[(name, dtype)].numel() < required_len:
self.buffer[(name, dtype)] = \
torch.empty(required_len,
dtype=dtype,
device=torch.cuda.current_device(),
requires_grad=False)
return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape)
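GlobalMemoryBuffer and its accessor leave global_vars.py; after this commit the preallocated buffer lives in the core package and is reached through mpu.get_global_memory_buffer(), as the CoreAttention hunk below shows. The get_tensor call shape is unchanged. A hedged sketch (shape, dtype, and buffer name are placeholders; the call assumes model parallelism is initialized and, per the removed code above, allocates on the current CUDA device):

    import torch
    from megatron.core import mpu

    # old: from megatron import get_global_memory_buffer
    #      buf = get_global_memory_buffer().get_tensor((4, 8), torch.float32, "mpu")
    buf = mpu.get_global_memory_buffer().get_tensor((4, 8), torch.float32, "mpu")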
......@@ -4,7 +4,7 @@ import torch
import torch.distributed as dist
from megatron import get_args, print_rank_0
from megatron import mpu
from megatron.core import mpu
from megatron.checkpointing import load_biencoder_checkpoint
from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
from megatron.data.orqa_wiki_dataset import get_open_retrieval_batch
......
......@@ -14,12 +14,10 @@ from megatron import fused_kernels
from megatron import get_adlr_autoresume
from megatron import get_args
from megatron import get_tensorboard_writer
from megatron import mpu
from megatron.core import mpu, tensor_parallel
from megatron.arguments import (parse_args, validate_args)
from megatron.checkpointing import load_args_from_checkpoint
from megatron.global_vars import set_global_variables
from megatron.mpu import (set_tensor_model_parallel_rank,
set_tensor_model_parallel_world_size)
from megatron.model.transformer import bias_dropout_add_fused_train
from megatron.model.fused_bias_gelu import bias_gelu
......@@ -64,13 +62,14 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
args = get_args()
if args.lazy_mpu_init:
# TODO is this still a necessary option?
args.use_cpu_initialization=True
# delayed initialization of DDP-related stuff
# We only set basic DDP globals
set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
# We only set basic DDP globals
mpu.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
# and return function for external DDP manager
# to call when it has DDP initialized
set_tensor_model_parallel_rank(args.rank)
mpu.set_tensor_model_parallel_rank(args.rank)
return finish_mpu_init
else:
# Megatron's MPU is the master. Complete initialization right away.
......@@ -146,7 +145,7 @@ def _compile_dependencies():
def _initialize_distributed():
"""Initialize torch.distributed and mpu."""
"""Initialize torch.distributed and core model parallel."""
args = get_args()
device_count = torch.cuda.device_count()
......@@ -184,9 +183,14 @@ def _initialize_distributed():
print('model parallel is already initialized')
else:
mpu.initialize_model_parallel(args.tensor_model_parallel_size,
args.pipeline_model_parallel_size,
args.virtual_pipeline_model_parallel_size,
args.pipeline_model_parallel_split_rank)
args.pipeline_model_parallel_size,
args.virtual_pipeline_model_parallel_size,
args.pipeline_model_parallel_split_rank)
if args.rank == 0:
print(f'> initialized tensor model parallel with size '
f'{mpu.get_tensor_model_parallel_world_size()}')
print(f'> initialized pipeline model parallel with size '
f'{mpu.get_pipeline_model_parallel_world_size()}')
def _init_autoresume():
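_initialize_distributed now hands the parallel sizes straight to megatron.core and reports the resulting group sizes itself. A minimal single-process sketch of that handoff, for illustration only (gloo backend, world size 1, hypothetical rendezvous address); real runs go through initialize_megatron(), which also sets devices and seeds:

    import torch
    from megatron.core import mpu

    torch.distributed.init_process_group(
        backend='gloo', init_method='tcp://127.0.0.1:29500',
        world_size=1, rank=0)
    mpu.initialize_model_parallel(1,   # tensor_model_parallel_size
                                  1)   # pipeline_model_parallel_size
    print('> initialized tensor model parallel with size',
          mpu.get_tensor_model_parallel_world_size())
    print('> initialized pipeline model parallel with size',
          mpu.get_pipeline_model_parallel_world_size())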
......@@ -210,7 +214,7 @@ def _set_random_seed(seed_, data_parallel_random_init=False):
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.device_count() > 0:
mpu.model_parallel_cuda_manual_seed(seed)
tensor_parallel.model_parallel_cuda_manual_seed(seed)
else:
raise ValueError('Seed ({}) should be a positive integer.'.format(seed))
......
......@@ -5,7 +5,7 @@
import torch
from megatron import get_args
from megatron import mpu
from megatron.core import tensor_parallel
from megatron.model.enums import AttnMaskType
from megatron.model.language_model import parallel_lm_logits
from megatron.model.language_model import get_language_model
......@@ -61,7 +61,7 @@ class BertLMHead(MegatronModule):
args = get_args()
self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
mpu.set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
tensor_parallel.set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
self.parallel_output = parallel_output
self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
......@@ -110,10 +110,10 @@ def post_language_model_processing(lm_output, pooled_output,
# lm_logits : [s, b, h] and lm_labels: [s, b]
if fp16_lm_cross_entropy:
assert lm_logits.dtype == torch.half
lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits, lm_labels)
else:
lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
lm_labels)
lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits.float(),
lm_labels)
# [s, b] => [b s]
lm_loss = lm_loss.transpose(0,1).contiguous()
return lm_loss, binary_logits
......
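The parallel cross-entropy moves with the other tensor-parallel ops, and the same one-line rename recurs in the GPT and T5 hunks below. A hedged sketch of the non-fp16 branch (logits and labels are placeholders shaped [s, b, vocab-partition] and [s, b] as in the hunk; the call assumes tensor parallelism is initialized):

    from megatron.core import tensor_parallel

    def lm_cross_entropy(lm_logits, lm_labels):
        # old: from megatron import mpu; mpu.vocab_parallel_cross_entropy(...)
        lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits.float(), lm_labels)
        return lm_loss.transpose(0, 1).contiguous()   # [s, b] -> [b, s], as in the hunk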
......@@ -2,11 +2,11 @@ import os
import torch
import sys
from megatron import get_args, print_rank_0
from megatron import get_args, print_rank_0, get_tokenizer
from megatron.core import mpu
from megatron.checkpointing import fix_query_key_value_ordering
from megatron.checkpointing import get_checkpoint_tracker_filename
from megatron.checkpointing import get_checkpoint_name
from megatron import mpu, get_tokenizer
from megatron.model.bert_model import bert_position_ids
from megatron.model.enums import AttnMaskType
from megatron.model.language_model import get_language_model
......
......@@ -5,7 +5,6 @@
import torch
from megatron import get_args, print_rank_last
from megatron import mpu
from megatron.model.enums import AttnMaskType
from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
from megatron.model.language_model import get_language_model
......
......@@ -8,7 +8,7 @@ import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from megatron import get_args
from megatron import mpu
from megatron.core import mpu
from .module import MegatronModule
......
......@@ -10,7 +10,7 @@ from torch.nn.parameter import Parameter
from torch.nn import init
import importlib
from megatron.mpu import make_viewless_tensor
from megatron.core.utils import make_viewless_tensor
try:
from apex.contrib.layer_norm.layer_norm import FastLayerNormFN
......
......@@ -5,7 +5,7 @@
import torch
from megatron import get_args
from megatron import mpu
from megatron.core import tensor_parallel
from .module import MegatronModule
from .enums import AttnMaskType
......@@ -33,9 +33,9 @@ def post_language_model_processing(lm_output, labels, logit_weights,
labels = labels.transpose(0,1).contiguous()
if fp16_lm_cross_entropy:
assert output.dtype == torch.half
loss = mpu.vocab_parallel_cross_entropy(output, labels)
loss = tensor_parallel.vocab_parallel_cross_entropy(output, labels)
else:
loss = mpu.vocab_parallel_cross_entropy(output.float(), labels)
loss = tensor_parallel.vocab_parallel_cross_entropy(output.float(), labels)
# [s b] => [b, s]
loss = loss.transpose(0,1).contiguous()
......
......@@ -6,7 +6,7 @@ import torch
import torch.nn.functional as F
from megatron import get_args
from megatron import mpu
from megatron.core import mpu, tensor_parallel
from .module import MegatronModule
from megatron.model.enums import LayerType, AttnMaskType
from megatron.model.transformer import ParallelTransformer
......@@ -26,20 +26,23 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
async_grad_allreduce = args.async_tensor_model_parallel_allreduce and \
model_parallel and not args.sequence_parallel
else:
input_parallel = mpu.copy_to_tensor_model_parallel_region(input_)
input_parallel = tensor_parallel.copy_to_tensor_model_parallel_region(input_)
async_grad_allreduce = False
# Matrix multiply.
logits_parallel = mpu.LinearWithGradAccumulationAndAsyncCommunication.apply(
input_parallel, word_embeddings_weight, bias,
args.gradient_accumulation_fusion,
async_grad_allreduce, args.sequence_parallel)
logits_parallel = tensor_parallel.linear_with_grad_accumulation_and_async_allreduce(
input=input_parallel,
weight=word_embeddings_weight,
bias=bias,
gradient_accumulation_fusion=args.gradient_accumulation_fusion,
async_grad_allreduce=async_grad_allreduce,
sequence_parallel_enabled=args.sequence_parallel)
# Gather if needed.
if parallel_output:
return logits_parallel
return mpu.gather_from_tensor_model_parallel_region(logits_parallel)
return tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel)
def get_language_model(num_tokentypes, add_pooler,
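Note the call-style change in parallel_lm_logits above: instead of invoking the autograd Function LinearWithGradAccumulationAndAsyncCommunication.apply with positional arguments, callers now go through the keyword-argument wrapper tensor_parallel.linear_with_grad_accumulation_and_async_allreduce. A hedged sketch of the new entry point (the wrapper name project_to_vocab is illustrative, the boolean flags stand in for the args.* values above, and the tensors must already be tensor-parallel-local):

    from megatron.core import tensor_parallel

    def project_to_vocab(input_parallel, word_embeddings_weight, bias=None):
        return tensor_parallel.linear_with_grad_accumulation_and_async_allreduce(
            input=input_parallel,
            weight=word_embeddings_weight,
            bias=bias,
            gradient_accumulation_fusion=False,
            async_grad_allreduce=False,
            sequence_parallel_enabled=False)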
......@@ -103,7 +106,7 @@ class Pooler(MegatronModule):
# gather data along sequence dimensions
# same pooler is run on all tensor parallel nodes
if self.sequence_parallel:
hidden_states = mpu.gather_from_sequence_parallel_region(
hidden_states = tensor_parallel.gather_from_sequence_parallel_region(
hidden_states,
tensor_parallel_output_grad=False)
......@@ -143,9 +146,13 @@ class Embedding(MegatronModule):
args = get_args()
# Word embeddings (parallel).
self.word_embeddings = mpu.VocabParallelEmbedding(
self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
vocab_size, self.hidden_size,
init_method=self.init_method)
init_method=self.init_method,
params_dtype=args.params_dtype,
use_cpu_initialization=args.use_cpu_initialization,
perform_initialization=args.perform_initialization
)
self._word_embeddings_key = 'word_embeddings'
# Position embedding (serial).
......@@ -222,8 +229,8 @@ class Embedding(MegatronModule):
# Dropout.
if self.sequence_parallel:
embeddings = mpu.scatter_to_sequence_parallel_region(embeddings)
with mpu.get_cuda_rng_tracker().fork():
embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
with tensor_parallel.get_cuda_rng_tracker().fork():
embeddings = self.embedding_dropout(embeddings)
else:
embeddings = self.embedding_dropout(embeddings)
......
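A design point visible in this hunk and in the MegatronModule and _args_to_kwargs hunks below: callers now pass params_dtype, use_cpu_initialization, and perform_initialization explicitly, which suggests the core layers take these options as parameters rather than pulling them from get_args(). A hedged construction sketch with placeholder sizes; it assumes model parallelism is already initialized (see the initialization sketch earlier):

    import torch
    from megatron.core import tensor_parallel
    from megatron.model.utils import init_method_normal

    word_embeddings = tensor_parallel.VocabParallelEmbedding(
        50304, 1024,                          # placeholder vocab and hidden sizes
        init_method=init_method_normal(0.02),
        params_dtype=torch.float32,
        use_cpu_initialization=True,
        perform_initialization=True)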
......@@ -7,7 +7,7 @@ from torch.autograd import Variable
from torch.nn.parameter import Parameter
from megatron import get_args
from megatron import mpu
from megatron.core import mpu, tensor_parallel
_FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
......@@ -76,9 +76,12 @@ class MegatronModule(torch.nn.Module):
self._word_embeddings_for_head_key = 'word_embeddings_for_head'
# set word_embeddings weights to 0 here, then copy first
# stage's weights using all_reduce below.
self.word_embeddings = mpu.VocabParallelEmbedding(
self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
args.padded_vocab_size, args.hidden_size,
init_method=init_method_normal(args.init_method_std))
init_method=init_method_normal(args.init_method_std),
params_dtype=args.params_dtype,
use_cpu_initialization=args.use_cpu_initialization,
perform_initialization=args.perform_initialization)
self.word_embeddings.weight.data.fill_(0)
self.word_embeddings.weight.shared = True
......
......@@ -5,7 +5,6 @@
import torch
from megatron import get_args, print_rank_last
from megatron import mpu
from megatron.model.enums import AttnMaskType
from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
from megatron.model.language_model import get_language_model
......
......@@ -5,7 +5,7 @@ from megatron import get_args, print_rank_0
from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
from megatron.model import BertModel
from .module import MegatronModule
from megatron import mpu
from megatron.core import mpu
from megatron.model.enums import AttnMaskType
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
......
......@@ -4,10 +4,8 @@
import torch
from megatron import (
get_args,
mpu
)
from megatron import get_args
from megatron.core import tensor_parallel
from megatron.model.enums import AttnMaskType
from megatron.model.language_model import parallel_lm_logits, get_language_model
from megatron.model.transformer import LayerNorm
......@@ -151,10 +149,10 @@ class T5Model(MegatronModule):
lm_labels = lm_labels.transpose(0,1).contiguous()
if self.fp16_lm_cross_entropy:
assert lm_logits.dtype == torch.half
lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits, lm_labels)
else:
lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
lm_labels)
lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits.float(),
lm_labels)
# [s b] => [b s]
lm_loss = lm_loss.transpose(0,1).contiguous()
return lm_loss
......
......@@ -6,9 +6,9 @@ from contextlib import nullcontext
import torch
import torch.nn.functional as F
from megatron import get_timers, get_args, get_global_memory_buffer
from megatron import mpu
from megatron import get_timers, get_args, core
from .module import MegatronModule
from megatron.core import mpu, tensor_parallel
from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
from megatron.model import LayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
......@@ -32,7 +32,7 @@ from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
"""
class DropPath(MegatronModule):
"""Drop paths (Stochastic Depth) per sample
"""Drop paths (Stochastic Depth) per sample
(when applied in main path of residual blocks).
"""
......@@ -52,6 +52,17 @@ class DropPath(MegatronModule):
output = hidden_state.div(keep_prob) * random_tensor
return output
def _args_to_kwargs():
args = get_args()
common_kwargs = {
"params_dtype": args.params_dtype,
"use_cpu_initialization": args.use_cpu_initialization,
"perform_initialization": args.perform_initialization,
"gradient_accumulation_fusion": args.gradient_accumulation_fusion,
"sequence_parallel_enabled": args.sequence_parallel,
}
return common_kwargs
class ParallelMLP(MegatronModule):
"""MLP.
......@@ -65,13 +76,16 @@ class ParallelMLP(MegatronModule):
super(ParallelMLP, self).__init__()
args = get_args()
# Project to 4h.
self.dense_h_to_4h = mpu.ColumnParallelLinear(
self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
args.hidden_size,
args.ffn_hidden_size,
gather_output=False,
init_method=init_method,
skip_bias_add=True)
skip_bias_add=True,
async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
**_args_to_kwargs())
self.bias_gelu_fusion = args.bias_gelu_fusion
self.activation_func = F.gelu
......@@ -81,12 +95,13 @@ class ParallelMLP(MegatronModule):
self.activation_func = erf_gelu
# Project back to h.
self.dense_4h_to_h = mpu.RowParallelLinear(
self.dense_4h_to_h = tensor_parallel.RowParallelLinear(
args.ffn_hidden_size,
args.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
skip_bias_add=True)
skip_bias_add=True,
**_args_to_kwargs())
def forward(self, hidden_states):
......@@ -136,7 +151,7 @@ class SwitchMLP(MegatronModule):
output_total = torch.empty_like(hidden_states)
output_bias_total = torch.empty_like(hidden_states)
#TODO (rprenger) This does each expert in serial, but it could be parallelized
for expert_num, expert in enumerate(self.experts):
local_indices = (max_ind == expert_num).nonzero()
hidden = hidden_states[local_indices,:]
......@@ -174,11 +189,11 @@ class CoreAttention(MegatronModule):
# Per attention head and per partition values.
world_size = mpu.get_tensor_model_parallel_world_size()
self.hidden_size_per_partition = mpu.divide(projection_size,
world_size)
self.hidden_size_per_attention_head = mpu.divide(
self.hidden_size_per_partition = core.utils.divide(projection_size,
world_size)
self.hidden_size_per_attention_head = core.utils.divide(
projection_size, args.num_attention_heads)
self.num_attention_heads_per_partition = mpu.divide(
self.num_attention_heads_per_partition = core.utils.divide(
args.num_attention_heads, world_size)
coeff = None
......@@ -221,7 +236,7 @@ class CoreAttention(MegatronModule):
output_size[0] * output_size[1], -1)
# preallocting input tensor: [b * np, sq, sk]
matmul_input_buffer = get_global_memory_buffer().get_tensor(
matmul_input_buffer = mpu.get_global_memory_buffer().get_tensor(
(output_size[0]*output_size[1], output_size[2], output_size[3]),
query_layer.dtype, "mpu")
......@@ -247,7 +262,7 @@ class CoreAttention(MegatronModule):
# seem a bit unusual, but is taken from the original Transformer paper.
if not self.sequence_parallel:
with mpu.get_cuda_rng_tracker().fork():
with tensor_parallel.get_cuda_rng_tracker().fork():
attention_probs = self.attention_dropout(attention_probs)
else:
attention_probs = self.attention_dropout(attention_probs)
......@@ -312,43 +327,51 @@ class ParallelAttention(MegatronModule):
# Per attention head and per partition values.
world_size = mpu.get_tensor_model_parallel_world_size()
self.hidden_size_per_attention_head = mpu.divide(
self.hidden_size_per_attention_head = core.utils.divide(
projection_size, args.num_attention_heads)
self.num_attention_heads_per_partition = mpu.divide(
self.num_attention_heads_per_partition = core.utils.divide(
args.num_attention_heads, world_size)
# Strided linear layer.
if attention_type == AttnType.self_attn:
self.query_key_value = mpu.ColumnParallelLinear(
self.query_key_value = tensor_parallel.ColumnParallelLinear(
args.hidden_size,
3 * projection_size,
gather_output=False,
init_method=init_method)
init_method=init_method,
async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
**_args_to_kwargs())
else:
assert attention_type == AttnType.cross_attn
self.query = mpu.ColumnParallelLinear(
self.query = tensor_parallel.ColumnParallelLinear(
args.hidden_size,
projection_size,
gather_output=False,
init_method=init_method)
init_method=init_method,
async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
**_args_to_kwargs())
self.key_value = mpu.ColumnParallelLinear(
self.key_value = tensor_parallel.ColumnParallelLinear(
args.hidden_size,
2 * projection_size,
gather_output=False,
init_method=init_method)
init_method=init_method,
async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
**_args_to_kwargs())
self.core_attention = CoreAttention(self.layer_number,
self.attn_mask_type)
self.checkpoint_core_attention = args.recompute_granularity == 'selective'
# Output.
self.dense = mpu.RowParallelLinear(
self.dense = tensor_parallel.RowParallelLinear(
projection_size,
args.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
skip_bias_add=True)
skip_bias_add=True,
**_args_to_kwargs())
def _checkpointed_attention_forward(self, query_layer, key_layer,
value_layer, attention_mask):
......@@ -362,7 +385,7 @@ class ParallelAttention(MegatronModule):
value_layer, attention_mask)
return output_
hidden_states = mpu.checkpoint(
hidden_states = tensor_parallel.checkpoint(
custom_forward,
False, query_layer, key_layer, value_layer, attention_mask)
......@@ -415,7 +438,7 @@ class ParallelAttention(MegatronModule):
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
(query_layer,
key_layer,
value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_x_layer, 3)
else:
# Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
mixed_kv_layer, _ = self.key_value(encoder_output)
......@@ -428,7 +451,7 @@ class ParallelAttention(MegatronModule):
# [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
(key_layer,
value_layer) = mpu.split_tensor_along_last_dim(mixed_kv_layer, 2)
value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)
# Attention head [sq, b, h] --> [sq, b, hp]
query_layer, _ = self.query(hidden_states)
......@@ -674,9 +697,9 @@ class ParallelTransformerLayer(MegatronModule):
# won't result in memory savings (like the data loader, or
# p2p_communication), it serves to document the origin of this
# 'view' tensor.
output = mpu.make_viewless_tensor(inp = output,
requires_grad = output.requires_grad,
keep_graph = True)
output = core.utils.make_viewless_tensor(inp = output,
requires_grad = output.requires_grad,
keep_graph = True)
else:
out = torch.nn.functional.dropout(mlp_output + mlp_bias,
......@@ -713,13 +736,65 @@ class NoopTransformerLayer(MegatronModule):
return hidden_states.clone()
def _get_num_layers(args, is_encoder_and_decoder_model, is_decoder=False):
"""Compute the number of transformer layers resident on the current rank."""
if get_pipeline_model_parallel_world_size() > 1:
if is_encoder_and_decoder_model:
assert args.pipeline_model_parallel_split_rank is not None
# When a standalone embedding stage is used, a rank is taken from
# the encoder's ranks, to be used for the encoder's embedding
# layer. This way, the rank referenced by the 'split rank' remains
# the same whether or not a standalone embedding stage is used.
num_ranks_in_encoder = (
args.pipeline_model_parallel_split_rank - 1
if args.standalone_embedding_stage else
args.pipeline_model_parallel_split_rank
)
num_ranks_in_decoder = args.transformer_pipeline_model_parallel_size - num_ranks_in_encoder
assert args.encoder_num_layers % num_ranks_in_encoder == 0, \
'encoder_num_layers (%d) must be divisible by number of ranks given to encoder (%d)' % (args.encoder_num_layers, num_ranks_in_encoder)
assert args.decoder_num_layers % num_ranks_in_decoder == 0, \
'decoder_num_layers (%d) must be divisible by number of ranks given to decoder (%d)' % (args.decoder_num_layers, num_ranks_in_decoder)
if is_pipeline_stage_before_split():
num_layers = (
0
if args.standalone_embedding_stage
and get_pipeline_model_parallel_rank() == 0 else
args.encoder_num_layers // num_ranks_in_encoder
)
else:
num_layers = args.decoder_num_layers // num_ranks_in_decoder
else:
assert args.num_layers == args.encoder_num_layers
assert args.num_layers % args.transformer_pipeline_model_parallel_size == 0, \
'num_layers must be divisible by transformer_pipeline_model_parallel_size'
# When a standalone embedding stage is used, all transformer layers
# are divided among pipeline rank >= 1, while on pipeline rank 0,
# ranks either contain the input embedding layer (virtual pp rank 0),
# or no layers at all (virtual pp rank >= 1).
num_layers = (
0
if args.standalone_embedding_stage
and get_pipeline_model_parallel_rank() == 0 else
args.num_layers // args.transformer_pipeline_model_parallel_size
)
else:
if not is_decoder:
num_layers = args.encoder_num_layers
else:
num_layers = args.decoder_num_layers
return num_layers
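_get_num_layers (moved here from the old mpu) splits the layer count across pipeline ranks. A small worked example of the encoder-and-decoder branch, runnable on its own; the sizes are hypothetical and assume no standalone embedding stage (so transformer_pipeline_model_parallel_size equals the pipeline size):

    # Hypothetical sizes: 24 encoder layers, 12 decoder layers, pipeline size 6,
    # pipeline_model_parallel_split_rank = 4, no standalone embedding stage.
    encoder_num_layers, decoder_num_layers = 24, 12
    transformer_pipeline_size, split_rank = 6, 4

    num_ranks_in_encoder = split_rank                                        # 4 encoder ranks
    num_ranks_in_decoder = transformer_pipeline_size - num_ranks_in_encoder  # 2 decoder ranks
    assert encoder_num_layers % num_ranks_in_encoder == 0
    assert decoder_num_layers % num_ranks_in_decoder == 0
    print(encoder_num_layers // num_ranks_in_encoder)   # 6 layers on each encoder rank
    print(decoder_num_layers // num_ranks_in_decoder)   # 6 layers on each decoder rank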
class ParallelTransformer(MegatronModule):
"""Transformer class."""
def __init__(self, init_method, output_layer_init_method,
layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding,
post_layer_norm=True,
post_layer_norm=True,
pre_process=True, post_process=True,
drop_path_rate=0.0):
super(ParallelTransformer, self).__init__()
......@@ -745,7 +820,7 @@ class ParallelTransformer(MegatronModule):
self.sequence_parallel = args.sequence_parallel
# Number of layers.
self.num_layers = mpu.get_num_layers(
self.num_layers = _get_num_layers(
args,
args.model_type == ModelType.encoder_and_decoder,
layer_type == LayerType.decoder)
......@@ -840,7 +915,7 @@ class ParallelTransformer(MegatronModule):
# A method to further reduce memory usage reducing checkpoints.
l = 0
while l < self.num_layers:
hidden_states = mpu.checkpoint(
hidden_states = tensor_parallel.checkpoint(
custom(l, l + self.recompute_num_layers),
self.distribute_saved_activations,
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
......@@ -852,7 +927,7 @@ class ParallelTransformer(MegatronModule):
# A method fully use the device memory removing redundant re-computation.
for l in range(self.num_layers):
if l < self.recompute_num_layers:
hidden_states = mpu.checkpoint(
hidden_states = tensor_parallel.checkpoint(
custom(l, l + 1),
self.distribute_saved_activations,
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
......@@ -898,19 +973,19 @@ class ParallelTransformer(MegatronModule):
# However, we don't explicitly check mbs == 1 here because
# make_viewless_tensor() has negligible overhead when its input
# is already viewless.
#
#
# - For the 'else' case above, calling make_viewless_tensor() here is
# likely redundant, since p2p_communication.py (likely originator)
# already creates viewless tensors. That said, make_viewless_tensor()
# is called here to be future-proof and corner-case-proof.
hidden_states = mpu.make_viewless_tensor(
hidden_states = core.utils.make_viewless_tensor(
hidden_states,
requires_grad=True,
keep_graph=True,
)
if self.sequence_parallel:
rng_context = mpu.get_cuda_rng_tracker().fork()
rng_context = tensor_parallel.get_cuda_rng_tracker().fork()
else:
rng_context = nullcontext()
......
import torch.nn.functional as F
import torch
from megatron import print_rank_0, get_args, mpu
from megatron import print_rank_0, get_args
from megatron.core import mpu
from megatron.data.vit_dataset import ClassificationTransform
from megatron.data.image_folder import ImageFolder
......