Commit c2ea914f authored by Jared Casper

Move layers from mpu to core.tensor_parallel.

parent 209f91c9
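The change is mechanical at the call sites: code that previously reached through the monolithic `megatron.mpu` package now goes through `megatron.core` for parallel state and `megatron.core.tensor_parallel` for the relocated layers, mappings, and RNG utilities. A minimal before/after sketch of the pattern applied throughout the hunks below (the surrounding code is illustrative only):

```python
# Before: everything hung off the mpu package.
from megatron import mpu

rank = mpu.get_tensor_model_parallel_rank()
with mpu.get_cuda_rng_tracker().fork():
    ...  # e.g. dropout under the model-parallel RNG state

# After: parallel state is re-exported from megatron.core, while the
# layer / mapping / RNG helpers live in megatron.core.tensor_parallel.
from megatron import core

rank = core.get_tensor_model_parallel_rank()
with core.tensor_parallel.get_cuda_rng_tracker().fork():
    ...
```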
......@@ -10,7 +10,6 @@ from .global_vars import get_tokenizer
from .global_vars import get_tensorboard_writer
from .global_vars import get_adlr_autoresume
from .global_vars import get_timers
from .global_vars import get_global_memory_buffer
from .initialize import initialize_megatron
from .utils import (print_rank_0,
......
......@@ -9,7 +9,7 @@ import numpy as np
import torch
from megatron import (mpu,
from megatron import (core,
update_num_microbatches)
from .global_vars import get_args
from .utils import (unwrap_model,
......@@ -79,11 +79,11 @@ def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
# Use both the tensor and pipeline MP rank.
if pipeline_parallel is None:
pipeline_parallel = (mpu.get_pipeline_model_parallel_world_size() > 1)
pipeline_parallel = (core.get_pipeline_model_parallel_world_size() > 1)
if tensor_rank is None:
tensor_rank = mpu.get_tensor_model_parallel_rank()
tensor_rank = core.get_tensor_model_parallel_rank()
if pipeline_rank is None:
pipeline_rank = mpu.get_pipeline_model_parallel_rank()
pipeline_rank = core.get_pipeline_model_parallel_rank()
# Use both the tensor and pipeline MP rank. If using the distributed
# optimizer, then the optimizer's path must additionally include the
......@@ -98,7 +98,7 @@ def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
if use_distributed_optimizer:
model_name = os.path.join(common_path, "model_rng.pt")
optim_name = os.path.join(
common_path + "_%03d" % mpu.get_data_parallel_rank(),
common_path + "_%03d" % core.get_data_parallel_rank(),
"optim.pt")
else:
model_name = optim_name = os.path.join(common_path, "model_optim_rng.pt")
......@@ -185,18 +185,18 @@ def get_rng_state():
'np_rng_state': np.random.get_state(),
'torch_rng_state': torch.get_rng_state(),
'cuda_rng_state': torch.cuda.get_rng_state(),
'rng_tracker_states': mpu.get_cuda_rng_tracker().get_states()}
'rng_tracker_states': core.tensor_parallel.get_cuda_rng_tracker().get_states()}
rng_state_list = None
if torch.distributed.is_initialized() and \
mpu.get_data_parallel_world_size() > 1 and \
core.get_data_parallel_world_size() > 1 and \
args.data_parallel_random_init:
rng_state_list = \
[None for i in range(mpu.get_data_parallel_world_size())]
[None for i in range(core.get_data_parallel_world_size())]
torch.distributed.all_gather_object(
rng_state_list,
rng_state,
group=mpu.get_data_parallel_group())
group=core.get_data_parallel_group())
else:
rng_state_list = [rng_state]
......@@ -223,7 +223,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
# Collect args, model, RNG.
model_state_dict = {}
if not torch.distributed.is_initialized() \
or mpu.get_data_parallel_rank() == 0:
or core.get_data_parallel_rank() == 0:
# Arguments, iteration, and model.
model_state_dict['args'] = args
......@@ -233,7 +233,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
model_state_dict['model'] = model[0].state_dict_for_save_checkpoint()
else:
for i in range(len(model)):
mpu.set_virtual_pipeline_model_parallel_rank(i)
core.set_virtual_pipeline_model_parallel_rank(i)
model_state_dict['model%d' % i] = \
model[i].state_dict_for_save_checkpoint()
......@@ -246,7 +246,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
optim_state_dict = {}
if not args.no_save_optim \
and (not torch.distributed.is_initialized()
or mpu.get_data_parallel_rank() == 0
or core.get_data_parallel_rank() == 0
or args.use_distributed_optimizer):
# Optimizer stuff.
......@@ -548,7 +548,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
model[0].load_state_dict(model_state_dict['model'], strict=strict)
else:
for i in range(len(model)):
mpu.set_virtual_pipeline_model_parallel_rank(i)
core.set_virtual_pipeline_model_parallel_rank(i)
model[i].load_state_dict(model_state_dict['model%d' % i], strict=strict)
# Fix up query/key/value matrix ordering if needed
......@@ -580,7 +580,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
# access rng_state for data parallel rank
if args.data_parallel_random_init:
rng_state = model_state_dict['rng_state'][mpu.get_data_parallel_rank()]
rng_state = model_state_dict['rng_state'][core.get_data_parallel_rank()]
else:
rng_state = model_state_dict['rng_state'][0]
random.setstate(rng_state['random_rng_state'])
......@@ -590,7 +590,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
# Check for empty states array
if not rng_state['rng_tracker_states']:
raise KeyError
mpu.get_cuda_rng_tracker().set_states(
core.tensor_parallel.get_cuda_rng_tracker().set_states(
rng_state['rng_tracker_states'])
else: # backward compatibility
random.setstate(model_state_dict['random_rng_state'])
......@@ -600,7 +600,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
# Check for empty states array
if not model_state_dict['rng_tracker_states']:
raise KeyError
mpu.get_cuda_rng_tracker().set_states(
core.tensor_parallel.get_cuda_rng_tracker().set_states(
model_state_dict['rng_tracker_states'])
except KeyError:
print_rank_0('Unable to load rng state from checkpoint {}. '
......@@ -640,7 +640,7 @@ def load_biencoder_checkpoint(model, only_query_model=False,
args.use_distributed_optimizer,
release=False)
if mpu.get_data_parallel_rank() == 0:
if core.get_data_parallel_rank() == 0:
print('global rank {} is loading checkpoint {}'.format(
torch.distributed.get_rank(), checkpoint_name))
......@@ -656,7 +656,7 @@ def load_biencoder_checkpoint(model, only_query_model=False,
model[0].load_state_dict(ret_state_dict)
torch.distributed.barrier()
if mpu.get_data_parallel_rank() == 0:
if core.get_data_parallel_rank() == 0:
print(' successfully loaded {}'.format(checkpoint_name))
return model
from .parallel_state import (
initialize_model_parallel,
get_tensor_model_parallel_world_size,
get_tensor_model_parallel_rank,
get_pipeline_model_parallel_world_size,
get_pipeline_model_parallel_rank,
get_virtual_pipeline_model_parallel_rank, set_virtual_pipeline_model_parallel_rank,
get_data_parallel_world_size,
get_data_parallel_rank,
get_global_memory_buffer,
get_num_layers,
)
from megatron.core import tensor_parallel
......@@ -5,6 +5,8 @@
import torch
from typing import Optional
from .utils import GlobalMemoryBuffer
# Intra-layer model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None
# Inter-layer model parallel group that the current rank belongs to.
......@@ -42,7 +44,8 @@ _PIPELINE_GLOBAL_RANKS = None
# rank when broadcasting weights from src to all other data parallel ranks
_DATA_PARALLEL_GLOBAL_RANKS = None
# Memory buffers to avoid dynamic memory allocation
_GLOBAL_MEMORY_BUFFER = None
def is_unitialized():
"""Useful for code segments that may be accessed with or without mpu initialization"""
......@@ -195,6 +198,12 @@ def initialize_model_parallel(
if rank in ranks:
_POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks
# Initialize global memory buffer
# This isn't really "parallel state" but there isn't another good place to
# put this. If we end up with a more generic initialization of megatron-core
# we could stick it there
_set_global_memory_buffer()
def model_parallel_is_initialized():
"""Check if model and data parallel groups are initialized."""
......@@ -506,6 +515,18 @@ def get_data_parallel_rank():
"""Return my rank for the data parallel group."""
return torch.distributed.get_rank(group=get_data_parallel_group())
def _set_global_memory_buffer():
"""Initialize global buffer"""
global _GLOBAL_MEMORY_BUFFER
assert _GLOBAL_MEMORY_BUFFER is None, 'global memory buffer is already initialized'
_GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer()
def get_global_memory_buffer():
assert _GLOBAL_MEMORY_BUFFER is not None, 'global memory buffer is not initialized'
return _GLOBAL_MEMORY_BUFFER
def destroy_model_parallel():
"""Set the groups to none."""
global _MODEL_PARALLEL_GROUP
......
from .cross_entropy import vocab_parallel_cross_entropy
from .data import broadcast_data
from .layers import (
ColumnParallelLinear,
RowParallelLinear,
VocabParallelEmbedding,
set_defaults_if_not_set_tensor_model_parallel_attributes,
copy_tensor_model_parallel_attributes,
param_is_not_tensor_parallel_duplicate,
linear_with_grad_accumulation_and_async_allreduce
)
from .mappings import (
copy_to_tensor_model_parallel_region,
gather_from_tensor_model_parallel_region,
gather_from_sequence_parallel_region,
scatter_to_tensor_model_parallel_region,
scatter_to_sequence_parallel_region,
)
from .random import (
checkpoint,
get_cuda_rng_tracker,
model_parallel_cuda_manual_seed
)
from .utils import split_tensor_along_last_dim
__all__ = [
# cross_entropy.py
"vocab_parallel_cross_entropy",
# data.py
"broadcast_data",
#layers.py
"ColumnParallelLinear",
"RowParallelLinear",
"VocabParallelEmbedding",
"set_defaults_if_not_set_tensor_model_parallel_attributes",
"copy_tensor_model_parallel_attributes",
"param_is_not_tensor_parallel_duplicate",
"linear_with_grad_accumulation_and_async_allreduce",
# mappings.py
"copy_to_tensor_model_parallel_region",
"gather_from_tensor_model_parallel_region",
"gather_from_sequence_parallel_region",
# "reduce_from_tensor_model_parallel_region",
"scatter_to_tensor_model_parallel_region",
"scatter_to_sequence_parallel_region",
# random.py
"checkpoint",
"get_cuda_rng_tracker",
"model_parallel_cuda_manual_seed",
# utils.py
"split_tensor_along_last_dim",
]
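With this package `__init__.py` in place, downstream code can import the relocated symbols either directly from the package or through the `core` namespace used at the updated call sites in this commit. A short illustrative sketch (not taken from the diff):

```python
# Direct package imports:
from megatron.core.tensor_parallel import (
    ColumnParallelLinear,
    vocab_parallel_cross_entropy,
    checkpoint,
)

# Or via the megatron.core re-export, as the updated call sites do:
from megatron import core

tracker = core.tensor_parallel.get_cuda_rng_tracker()
```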
......@@ -2,7 +2,11 @@
import torch
from .initialize import get_tensor_model_parallel_group, get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank
from megatron.core.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_tensor_model_parallel_group,
)
from .utils import split_tensor_along_last_dim
......
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch
......@@ -11,12 +10,12 @@ from torch import _C
from torch.cuda import _lazy_call, device as device_ctx_manager
from torch.utils.checkpoint import detach_variable
from megatron.memory import allocate_mem_buff
from .initialize import get_data_parallel_rank
from .initialize import get_tensor_model_parallel_group
from .initialize import get_tensor_model_parallel_rank
from .initialize import get_tensor_model_parallel_world_size
from megatron.core.parallel_state import (
get_data_parallel_rank,
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
# Default name for the model parallel rng tracker.
......@@ -89,85 +88,6 @@ def gather_split_1d_tensor(tensor):
return gathered
def _kernel_make_viewless_tensor(inp, requires_grad):
'''Make a viewless tensor.
View tensors have the undesirable side-effect of retaining a reference
to the originally-viewed tensor, even after manually setting the '.data'
field. This method creates a new tensor that links to the old tensor's
data, without linking the viewed tensor, referenced via the '._base'
field.
'''
out = torch.empty(
(1,),
dtype = inp.dtype,
device = inp.device,
requires_grad = requires_grad,
)
out.data = inp.data
return out
class MakeViewlessTensor(torch.autograd.Function):
'''
Autograd function to make a viewless tensor.
This function should be used in cases where the computation graph needs
to be propagated, but we only want a viewless tensor (e.g.,
ParallelTransformer's hidden_states). Call this function by passing
'keep_graph = True' to 'make_viewless_tensor()'.
'''
@staticmethod
def forward(ctx, inp, requires_grad):
return _kernel_make_viewless_tensor(inp, requires_grad)
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
def make_viewless_tensor(inp, requires_grad, keep_graph):
'''
Entry-point for creating viewless tensors.
This method should be used, rather than calling 'MakeViewlessTensor'
or '_kernel_make_viewless_tensor' directly. This method acts as a
switch for determining if an autograd function or a regular method
should be used to create the tensor.
'''
# return tensor as-is, if not a 'view'
if inp._base is None:
return inp
# create viewless tensor
if keep_graph:
return MakeViewlessTensor.apply(inp, requires_grad)
else:
return _kernel_make_viewless_tensor(inp, requires_grad)
def assert_viewless_tensor(tensor, extra_msg = None):
'''Assert that a tensor is not a view (i.e., its '._base' field is
not set).'''
if isinstance(tensor, list):
[ assert_viewless_tensor(t) for t in tensor ]
return tensor
if not isinstance(tensor, torch.Tensor):
return tensor
assert tensor._base is None, (
"Ensure tensor._base is None before setting tensor.data or storing "
"tensor to memory buffer. Otherwise, a memory leak will occur (and "
"likely accumulate over iterations). %s"
) % extra_msg
return tensor
def safely_set_viewless_tensor_data(tensor, new_data_tensor):
'''Safely set tensor's '.data' field.
Check first that the tensor is viewless (i.e., '._base' not set). If not,
raise an exception.
'''
assert_viewless_tensor(tensor, extra_msg = "FYI, tensor._base has shape %s, and new_data_tensor has shape %s." % ("--" if tensor._base is None else tensor._base.shape, new_data_tensor.shape))
tensor.data = new_data_tensor
class CudaRNGStatesTracker:
"""Tracker for the cuda RNG states.
......
......@@ -5,19 +5,6 @@ from typing import List, Sequence
from megatron.core.utils import divide
def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
assert numerator % denominator == 0, '{} is not divisible by {}'.format(
numerator, denominator)
def divide(numerator, denominator):
"""Ensure that numerator is divisible by the denominator and return
the division value."""
ensure_divisibility(numerator, denominator)
return numerator // denominator
def split_tensor_along_last_dim(
tensor: torch.Tensor,
num_partitions: int,
......
"""Utility functions used through Megatron core"""
"""Utility functions used throughout Megatron core"""
from functools import reduce
import operator
import torch
from megatron.core import parallel_state
......@@ -46,3 +49,101 @@ def gather_split_1d_tensor(tensor):
group=parallel_state.get_tensor_model_parallel_group()
)
return gathered
class GlobalMemoryBuffer:
"""Global buffer to avoid dynamic memory allocations.
Caller should ensure that buffers of the same name
are not used concurrently."""
def __init__(self):
self.buffer = {}
def get_tensor(self, tensor_shape, dtype, name):
required_len = reduce(operator.mul, tensor_shape, 1)
if self.buffer.get((name, dtype), None) is None or \
self.buffer[(name, dtype)].numel() < required_len:
self.buffer[(name, dtype)] = \
torch.empty(required_len,
dtype=dtype,
device=torch.cuda.current_device(),
requires_grad=False)
return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape)
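For context, a hedged usage sketch of the relocated buffer: `get_tensor()` returns a view of a cached flat tensor keyed by `(name, dtype)` that only grows when a larger request arrives, so steady-state iterations avoid fresh `torch.empty` allocations. The shapes and the buffer name below are illustrative, and the call assumes model-parallel state has already been initialized on a CUDA device:

```python
import torch
from megatron.core import get_global_memory_buffer  # re-export added in this commit

buf = get_global_memory_buffer()
# Both requests are served from one cached flat allocation keyed by
# (name, dtype); the second, smaller request reuses the first one's storage.
scores = buf.get_tensor((8, 1024, 1024), torch.float16, "mpu")
probs = buf.get_tensor((4, 512, 512), torch.float16, "mpu")
```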
def _kernel_make_viewless_tensor(inp, requires_grad):
'''Make a viewless tensor.
View tensors have the undesirable side-effect of retaining a reference
to the originally-viewed tensor, even after manually setting the '.data'
field. This method creates a new tensor that links to the old tensor's
data, without linking the viewed tensor, referenced via the '._base'
field.
'''
out = torch.empty(
(1,),
dtype = inp.dtype,
device = inp.device,
requires_grad = requires_grad,
)
out.data = inp.data
return out
class MakeViewlessTensor(torch.autograd.Function):
'''
Autograd function to make a viewless tensor.
This function should be used in cases where the computation graph needs
to be propagated, but we only want a viewless tensor (e.g.,
ParallelTransformer's hidden_states). Call this function by passing
'keep_graph = True' to 'make_viewless_tensor()'.
'''
@staticmethod
def forward(ctx, inp, requires_grad):
return _kernel_make_viewless_tensor(inp, requires_grad)
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
def make_viewless_tensor(inp, requires_grad, keep_graph):
'''
Entry-point for creating viewless tensors.
This method should be used, rather than calling 'MakeViewlessTensor'
or '_kernel_make_viewless_tensor' directly. This method acts as a
switch for determining if an autograd function or a regular method
should be used to create the tensor.
'''
# return tensor as-is, if not a 'view'
if inp._base is None:
return inp
# create viewless tensor
if keep_graph:
return MakeViewlessTensor.apply(inp, requires_grad)
else:
return _kernel_make_viewless_tensor(inp, requires_grad)
def assert_viewless_tensor(tensor, extra_msg = None):
'''Assert that a tensor is not a view (i.e., its '._base' field is
not set).'''
if isinstance(tensor, list):
[ assert_viewless_tensor(t) for t in tensor ]
return tensor
if not isinstance(tensor, torch.Tensor):
return tensor
assert tensor._base is None, (
"Ensure tensor._base is None before setting tensor.data or storing "
"tensor to memory buffer. Otherwise, a memory leak will occur (and "
"likely accumulate over iterations). %s"
) % extra_msg
return tensor
def safely_set_viewless_tensor_data(tensor, new_data_tensor):
'''Safely set tensor's '.data' field.
Check first that the tensor is viewless (i.e., '._base' not set). If not,
raise an exception.
'''
assert_viewless_tensor(tensor, extra_msg = "FYI, tensor._base has shape %s, and new_data_tensor has shape %s." % ("--" if tensor._base is None else tensor._base.shape, new_data_tensor.shape))
tensor.data = new_data_tensor
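A minimal sketch of how the viewless-tensor helpers relocated into `core.utils` combine; in Megatron proper they are applied to transformer `hidden_states` handed across pipeline stages, while here a toy tensor stands in:

```python
import torch
from megatron.core.utils import (
    make_viewless_tensor,
    safely_set_viewless_tensor_data,
)

x = torch.randn(4, 8, requires_grad=True)
view = x.transpose(0, 1)  # a view: keeps a '._base' reference to x

# Re-wrap the storage without the '._base' link, keeping the autograd graph.
out = make_viewless_tensor(view, requires_grad=True, keep_graph=True)

# Safe to swap the underlying storage now; calling this on the raw 'view'
# would trip the assert in assert_viewless_tensor().
safely_set_viewless_tensor_data(out, torch.zeros(8, 4))
```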
......@@ -4,8 +4,6 @@
import os
import sys
from functools import reduce
import operator
import torch
from megatron import dist_signal_handler
......@@ -20,7 +18,6 @@ _GLOBAL_TENSORBOARD_WRITER = None
_GLOBAL_ADLR_AUTORESUME = None
_GLOBAL_TIMERS = None
_GLOBAL_SIGNAL_HANDLER = None
_GLOBAL_MEMORY_BUFFER = None
def get_args():
"""Return arguments."""
......@@ -70,11 +67,6 @@ def get_signal_handler():
return _GLOBAL_SIGNAL_HANDLER
def get_global_memory_buffer():
_ensure_var_is_initialized(_GLOBAL_MEMORY_BUFFER, 'global memory buffer')
return _GLOBAL_MEMORY_BUFFER
def _set_signal_handler():
global _GLOBAL_SIGNAL_HANDLER
_ensure_var_is_not_initialized(_GLOBAL_SIGNAL_HANDLER, 'signal handler')
......@@ -96,7 +88,6 @@ def set_global_variables(args):
_set_tensorboard_writer(args)
_set_adlr_autoresume(args)
_set_timers(args)
_set_global_memory_buffer()
if args.exit_signal_handler:
_set_signal_handler()
......@@ -176,13 +167,6 @@ def _set_timers(args):
_GLOBAL_TIMERS = Timers(args.timing_log_level, args.timing_log_option)
def _set_global_memory_buffer():
"""Initialize global buffer"""
global _GLOBAL_MEMORY_BUFFER
_ensure_var_is_not_initialized(_GLOBAL_MEMORY_BUFFER, 'global memory buffer')
_GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer()
def _ensure_var_is_initialized(var, name):
"""Make sure the input variable is not None."""
assert var is not None, '{} is not initialized.'.format(name)
......@@ -194,22 +178,3 @@ def _ensure_var_is_not_initialized(var, name):
class GlobalMemoryBuffer:
"""Global buffer to avoid dynamic memory allocations.
Caller should ensure that buffers of the same name
are not used concurrently."""
def __init__(self):
self.buffer = {}
def get_tensor(self, tensor_shape, dtype, name):
required_len = reduce(operator.mul, tensor_shape, 1)
if self.buffer.get((name, dtype), None) is None or \
self.buffer[(name, dtype)].numel() < required_len:
self.buffer[(name, dtype)] = \
torch.empty(required_len,
dtype=dtype,
device=torch.cuda.current_device(),
requires_grad=False)
return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape)
......@@ -219,7 +219,7 @@ def _set_random_seed(seed_, data_parallel_random_init=False):
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.device_count() > 0:
mpu.model_parallel_cuda_manual_seed(seed)
core.tensor_parallel.model_parallel_cuda_manual_seed(seed)
else:
raise ValueError('Seed ({}) should be a positive integer.'.format(seed))
......
......@@ -10,7 +10,7 @@ from torch.nn.parameter import Parameter
from torch.nn import init
import importlib
from megatron.mpu import make_viewless_tensor
from megatron.core.utils import make_viewless_tensor
try:
from apex.contrib.layer_norm.layer_norm import FastLayerNormFN
......
......@@ -6,7 +6,7 @@ import torch
import torch.nn.functional as F
from megatron import get_args
from megatron import mpu
from megatron import core
from .module import MegatronModule
from megatron.model.enums import LayerType, AttnMaskType
from megatron.model.transformer import ParallelTransformer
......@@ -22,24 +22,27 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
if args.async_tensor_model_parallel_allreduce or\
args.sequence_parallel:
input_parallel = input_
model_parallel = mpu.get_tensor_model_parallel_world_size() > 1
model_parallel = core.get_tensor_model_parallel_world_size() > 1
async_grad_allreduce = args.async_tensor_model_parallel_allreduce and \
model_parallel and not args.sequence_parallel
else:
input_parallel = mpu.copy_to_tensor_model_parallel_region(input_)
input_parallel = core.tensor_parallel.copy_to_tensor_model_parallel_region(input_)
async_grad_allreduce = False
# Matrix multiply.
logits_parallel = mpu.LinearWithGradAccumulationAndAsyncCommunication.apply(
input_parallel, word_embeddings_weight, bias,
args.gradient_accumulation_fusion,
async_grad_allreduce, args.sequence_parallel)
logits_parallel = core.tensor_parallel.linear_with_grad_accumulation_and_async_allreduce(
input=input_parallel,
weight=word_embeddings_weight,
bias=bias,
gradient_accumulation_fusion=args.gradient_accumulation_fusion,
async_grad_allreduce=async_grad_allreduce,
sequence_parallel_enabled=args.sequence_parallel)
# Gather if needed.
if parallel_output:
return logits_parallel
return mpu.gather_from_tensor_model_parallel_region(logits_parallel)
return core.tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel)
def get_language_model(num_tokentypes, add_pooler,
......@@ -103,7 +106,7 @@ class Pooler(MegatronModule):
# gather data along sequence dimensions
# same pooler is run on all tensor parallel nodes
if self.sequence_parallel:
hidden_states = mpu.gather_from_sequence_parallel_region(
hidden_states = core.tensor_parallel.gather_from_sequence_parallel_region(
hidden_states,
tensor_parallel_output_grad=False)
......@@ -143,9 +146,13 @@ class Embedding(MegatronModule):
args = get_args()
# Word embeddings (parallel).
self.word_embeddings = mpu.VocabParallelEmbedding(
self.word_embeddings = core.tensor_parallel.VocabParallelEmbedding(
vocab_size, self.hidden_size,
init_method=self.init_method)
init_method=self.init_method,
params_dtype=args.params_dtype,
use_cpu_initialization=args.use_cpu_initialization,
perform_initialization=args.perform_initialization
)
self._word_embeddings_key = 'word_embeddings'
# Position embedding (serial).
......@@ -222,8 +229,8 @@ class Embedding(MegatronModule):
# Dropout.
if self.sequence_parallel:
embeddings = mpu.scatter_to_sequence_parallel_region(embeddings)
with mpu.get_cuda_rng_tracker().fork():
embeddings = core.tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
with core.tensor_parallel.get_cuda_rng_tracker().fork():
embeddings = self.embedding_dropout(embeddings)
else:
embeddings = self.embedding_dropout(embeddings)
......
......@@ -8,6 +8,7 @@ from torch.nn.parameter import Parameter
from megatron import get_args
from megatron import mpu
from megatron import core
_FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
......@@ -76,9 +77,12 @@ class MegatronModule(torch.nn.Module):
self._word_embeddings_for_head_key = 'word_embeddings_for_head'
# set word_embeddings weights to 0 here, then copy first
# stage's weights using all_reduce below.
self.word_embeddings = mpu.VocabParallelEmbedding(
self.word_embeddings = core.tensor_parallel.VocabParallelEmbedding(
args.padded_vocab_size, args.hidden_size,
init_method=init_method_normal(args.init_method_std))
init_method=init_method_normal(args.init_method_std),
params_dtype=args.params_dtype,
use_cpu_initialization=args.use_cpu_initialization,
perform_initialization=args.perform_initialization)
self.word_embeddings.weight.data.fill_(0)
self.word_embeddings.weight.shared = True
......
......@@ -6,8 +6,9 @@ from contextlib import nullcontext
import torch
import torch.nn.functional as F
from megatron import get_timers, get_args, get_global_memory_buffer
from megatron import mpu
from megatron import get_timers, get_args
from megatron.core import get_global_memory_buffer
from megatron import core
from .module import MegatronModule
from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
from megatron.model import LayerNorm
......@@ -32,7 +33,7 @@ from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
"""
class DropPath(MegatronModule):
"""Drop paths (Stochastic Depth) per sample
"""Drop paths (Stochastic Depth) per sample
(when applied in main path of residual blocks).
"""
......@@ -52,6 +53,17 @@ class DropPath(MegatronModule):
output = hidden_state.div(keep_prob) * random_tensor
return output
def _args_to_kwargs():
args = get_args()
common_kwargs = {
"params_dtype": args.params_dtype,
"use_cpu_initialization": args.use_cpu_initialization,
"perform_initialization": args.perform_initialization,
"gradient_accumulation_fusion": args.gradient_accumulation_fusion,
"sequence_parallel_enabled": args.sequence_parallel,
}
return common_kwargs
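This helper collects the configuration that the relocated `core.tensor_parallel` layers now take as explicit constructor arguments rather than reading Megatron's global args themselves (that motivation is an inference from the call-site changes below, not stated in the diff). `**_args_to_kwargs()` in the constructors is plain keyword expansion; with a hypothetical bf16, non-sequence-parallel configuration it would contribute:

```python
import torch

# Illustrative expansion only -- the real values come from get_args().
common_kwargs = {
    "params_dtype": torch.bfloat16,
    "use_cpu_initialization": False,
    "perform_initialization": True,
    "gradient_accumulation_fusion": False,
    "sequence_parallel_enabled": False,
}
# e.g. core.tensor_parallel.RowParallelLinear(..., skip_bias_add=True,
#                                             **common_kwargs)
```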
class ParallelMLP(MegatronModule):
"""MLP.
......@@ -65,13 +77,16 @@ class ParallelMLP(MegatronModule):
super(ParallelMLP, self).__init__()
args = get_args()
# Project to 4h.
self.dense_h_to_4h = mpu.ColumnParallelLinear(
self.dense_h_to_4h = core.tensor_parallel.ColumnParallelLinear(
args.hidden_size,
args.ffn_hidden_size,
gather_output=False,
init_method=init_method,
skip_bias_add=True)
skip_bias_add=True,
async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
**_args_to_kwargs())
self.bias_gelu_fusion = args.bias_gelu_fusion
self.activation_func = F.gelu
......@@ -81,12 +96,13 @@ class ParallelMLP(MegatronModule):
self.activation_func = erf_gelu
# Project back to h.
self.dense_4h_to_h = mpu.RowParallelLinear(
self.dense_4h_to_h = core.tensor_parallel.RowParallelLinear(
args.ffn_hidden_size,
args.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
skip_bias_add=True)
skip_bias_add=True,
**_args_to_kwargs())
def forward(self, hidden_states):
......@@ -136,7 +152,7 @@ class SwitchMLP(MegatronModule):
output_total = torch.empty_like(hidden_states)
output_bias_total = torch.empty_like(hidden_states)
#TODO (rprenger) This does each expert in serial, but it could be parallelized
for expert_num, expert in enumerate(self.experts):
local_indices = (max_ind == expert_num).nonzero()
hidden = hidden_states[local_indices,:]
......@@ -173,12 +189,12 @@ class CoreAttention(MegatronModule):
projection_size = args.kv_channels * args.num_attention_heads
# Per attention head and per partition values.
world_size = mpu.get_tensor_model_parallel_world_size()
self.hidden_size_per_partition = mpu.divide(projection_size,
world_size)
self.hidden_size_per_attention_head = mpu.divide(
world_size = core.get_tensor_model_parallel_world_size()
self.hidden_size_per_partition = core.utils.divide(projection_size,
world_size)
self.hidden_size_per_attention_head = core.utils.divide(
projection_size, args.num_attention_heads)
self.num_attention_heads_per_partition = mpu.divide(
self.num_attention_heads_per_partition = core.utils.divide(
args.num_attention_heads, world_size)
coeff = None
......@@ -247,7 +263,7 @@ class CoreAttention(MegatronModule):
# seem a bit unusual, but is taken from the original Transformer paper.
if not self.sequence_parallel:
with mpu.get_cuda_rng_tracker().fork():
with core.tensor_parallel.get_cuda_rng_tracker().fork():
attention_probs = self.attention_dropout(attention_probs)
else:
attention_probs = self.attention_dropout(attention_probs)
......@@ -311,44 +327,52 @@ class ParallelAttention(MegatronModule):
projection_size = args.kv_channels * args.num_attention_heads
# Per attention head and per partition values.
world_size = mpu.get_tensor_model_parallel_world_size()
self.hidden_size_per_attention_head = mpu.divide(
world_size = core.get_tensor_model_parallel_world_size()
self.hidden_size_per_attention_head = core.utils.divide(
projection_size, args.num_attention_heads)
self.num_attention_heads_per_partition = mpu.divide(
self.num_attention_heads_per_partition = core.utils.divide(
args.num_attention_heads, world_size)
# Strided linear layer.
if attention_type == AttnType.self_attn:
self.query_key_value = mpu.ColumnParallelLinear(
self.query_key_value = core.tensor_parallel.ColumnParallelLinear(
args.hidden_size,
3 * projection_size,
gather_output=False,
init_method=init_method)
init_method=init_method,
async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
**_args_to_kwargs())
else:
assert attention_type == AttnType.cross_attn
self.query = mpu.ColumnParallelLinear(
self.query = core.tensor_parallel.ColumnParallelLinear(
args.hidden_size,
projection_size,
gather_output=False,
init_method=init_method)
init_method=init_method,
async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
**_args_to_kwargs())
self.key_value = mpu.ColumnParallelLinear(
self.key_value = core.tensor_parallel.ColumnParallelLinear(
args.hidden_size,
2 * projection_size,
gather_output=False,
init_method=init_method)
init_method=init_method,
async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
**_args_to_kwargs())
self.core_attention = CoreAttention(self.layer_number,
self.attn_mask_type)
self.checkpoint_core_attention = args.recompute_granularity == 'selective'
# Output.
self.dense = mpu.RowParallelLinear(
self.dense = core.tensor_parallel.RowParallelLinear(
projection_size,
args.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
skip_bias_add=True)
skip_bias_add=True,
**_args_to_kwargs())
def _checkpointed_attention_forward(self, query_layer, key_layer,
value_layer, attention_mask):
......@@ -362,7 +386,7 @@ class ParallelAttention(MegatronModule):
value_layer, attention_mask)
return output_
hidden_states = mpu.checkpoint(
hidden_states = core.tensor_parallel.checkpoint(
custom_forward,
False, query_layer, key_layer, value_layer, attention_mask)
......@@ -415,7 +439,7 @@ class ParallelAttention(MegatronModule):
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
(query_layer,
key_layer,
value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
value_layer) = core.tensor_parallel.split_tensor_along_last_dim(mixed_x_layer, 3)
else:
# Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
mixed_kv_layer, _ = self.key_value(encoder_output)
......@@ -428,7 +452,7 @@ class ParallelAttention(MegatronModule):
# [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
(key_layer,
value_layer) = mpu.split_tensor_along_last_dim(mixed_kv_layer, 2)
value_layer) = core.tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)
# Attention head [sq, b, h] --> [sq, b, hp]
query_layer, _ = self.query(hidden_states)
......@@ -674,9 +698,9 @@ class ParallelTransformerLayer(MegatronModule):
# won't result in memory savings (like the data loader, or
# p2p_communication), it serves to document the origin of this
# 'view' tensor.
output = mpu.make_viewless_tensor(inp = output,
requires_grad = output.requires_grad,
keep_graph = True)
output = core.utils.make_viewless_tensor(inp = output,
requires_grad = output.requires_grad,
keep_graph = True)
else:
out = torch.nn.functional.dropout(mlp_output + mlp_bias,
......@@ -719,7 +743,7 @@ class ParallelTransformer(MegatronModule):
def __init__(self, init_method, output_layer_init_method,
layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding,
post_layer_norm=True,
post_layer_norm=True,
pre_process=True, post_process=True,
drop_path_rate=0.0):
super(ParallelTransformer, self).__init__()
......@@ -745,7 +769,7 @@ class ParallelTransformer(MegatronModule):
self.sequence_parallel = args.sequence_parallel
# Number of layers.
self.num_layers = mpu.get_num_layers(
self.num_layers = core.get_num_layers(
args, args.model_type == ModelType.encoder_and_decoder)
self.drop_path_rates = [rate.item() for rate in torch.linspace(0, self.drop_path_rate, args.num_layers)]
......@@ -775,21 +799,21 @@ class ParallelTransformer(MegatronModule):
# layers to stages like (each list is a model chunk):
# Stage 0: [0, 1] [4, 5]
# Stage 1: [2, 3] [6, 7]
offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
offset = core.get_virtual_pipeline_model_parallel_rank() * (
args.num_layers // args.virtual_pipeline_model_parallel_size) + \
(mpu.get_pipeline_model_parallel_rank() * self.num_layers)
(core.get_pipeline_model_parallel_rank() * self.num_layers)
else:
# Each stage gets a contiguous set of layers.
if args.model_type == ModelType.encoder_and_decoder and \
mpu.get_pipeline_model_parallel_world_size() > 1:
pipeline_rank = mpu.get_pipeline_model_parallel_rank()
core.get_pipeline_model_parallel_world_size() > 1:
pipeline_rank = core.get_pipeline_model_parallel_rank()
if layer_type == LayerType.encoder:
offset = pipeline_rank * self.num_layers
else:
num_ranks_in_enc = args.pipeline_model_parallel_split_rank
offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
else:
offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
offset = core.get_pipeline_model_parallel_rank() * self.num_layers
if self.num_layers == 0:
# When a standalone embedding stage is used (e.g.,
......@@ -838,7 +862,7 @@ class ParallelTransformer(MegatronModule):
# A method to further reduce memory usage by reducing checkpoints.
l = 0
while l < self.num_layers:
hidden_states = mpu.checkpoint(
hidden_states = core.tensor_parallel.checkpoint(
custom(l, l + self.recompute_num_layers),
self.distribute_saved_activations,
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
......@@ -850,7 +874,7 @@ class ParallelTransformer(MegatronModule):
# A method that fully uses the device memory, removing redundant re-computation.
for l in range(self.num_layers):
if l < self.recompute_num_layers:
hidden_states = mpu.checkpoint(
hidden_states = core.tensor_parallel.checkpoint(
custom(l, l + 1),
self.distribute_saved_activations,
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
......@@ -896,19 +920,19 @@ class ParallelTransformer(MegatronModule):
# However, we don't explicitly check mbs == 1 here because
# make_viewless_tensor() has negligible overhead when its input
# is already viewless.
#
#
# - For the 'else' case above, calling make_viewless_tensor() here is
# likely redundant, since p2p_communication.py (likely originator)
# already creates viewless tensors. That said, make_viewless_tensor()
# is called here to be future-proof and corner-case-proof.
hidden_states = mpu.make_viewless_tensor(
hidden_states = core.utils.make_viewless_tensor(
hidden_states,
requires_grad=True,
keep_graph=True,
)
if self.sequence_parallel:
rng_context = mpu.get_cuda_rng_tracker().fork()
rng_context = core.tensor_parallel.get_cuda_rng_tracker().fork()
else:
rng_context = nullcontext()
......
......@@ -32,30 +32,6 @@ from .initialize import get_virtual_pipeline_model_parallel_rank, set_virtual_pi
from .initialize import initialize_model_parallel
from .initialize import model_parallel_is_initialized
from .layers import LinearWithGradAccumulationAndAsyncCommunication
from .layers import ColumnParallelLinear
from .layers import RowParallelLinear
from .layers import VocabParallelEmbedding
from .layers import (set_tensor_model_parallel_attributes,
set_defaults_if_not_set_tensor_model_parallel_attributes,
copy_tensor_model_parallel_attributes)
from .mappings import copy_to_tensor_model_parallel_region
from .mappings import reduce_from_tensor_model_parallel_region
from .mappings import scatter_to_tensor_model_parallel_region
from .mappings import gather_from_tensor_model_parallel_region
from .mappings import scatter_to_sequence_parallel_region
from .mappings import gather_from_sequence_parallel_region
from .mappings import reduce_scatter_to_sequence_parallel_region
from .random import checkpoint
from .random import get_cuda_rng_tracker
from .random import model_parallel_cuda_manual_seed
from .random import gather_split_1d_tensor
from .random import split_tensor_into_1d_equal_chunks
from .random import make_viewless_tensor
from .random import assert_viewless_tensor
from .random import safely_set_viewless_tensor_data
from .utils import divide
from .utils import split_tensor_along_last_dim
......@@ -9,7 +9,7 @@ from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate
def clip_grad_norm_fp32(parameters, grads_for_norm,
......
......@@ -11,7 +11,7 @@ from megatron import get_timers
from megatron import mpu
from megatron import print_rank_0
from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate
from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper
......
......@@ -12,11 +12,12 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from megatron import get_timers
from megatron import mpu
from megatron import core
from megatron import print_rank_0
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate
from megatron.utils import unwrap_model
from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32
......@@ -527,8 +528,8 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
# Create a copy
main_param = param.detach().clone().float()
# Copy tensor model parallel attributes.
mpu.copy_tensor_model_parallel_attributes(main_param,
param)
core.tensor_parallel.copy_tensor_model_parallel_attributes(main_param,
param)
if hasattr(param, 'shared'):
main_param.shared = param.shared
# Replace the optimizer params with the new fp32 copy.
......