Commit c2ea914f authored by Jared Casper

Move layers from mpu to core.tensor_parallel.

parent 209f91c9
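In short, call sites that used the old `megatron.mpu` namespace now go through `megatron.core` (parallel-state queries) and `megatron.core.tensor_parallel` (the moved layers, mappings, and RNG utilities). The sketch below is illustrative only and assumes an initialized Megatron model-parallel environment; it mirrors the substitutions repeated in the hunks that follow.

```python
# Before this commit:
#   from megatron import mpu
#   rank = mpu.get_tensor_model_parallel_rank()
#   tracker = mpu.get_cuda_rng_tracker()

# After this commit:
from megatron import core
from megatron.core import tensor_parallel

rank = core.get_tensor_model_parallel_rank()      # parallel-state queries move to megatron.core
tracker = tensor_parallel.get_cuda_rng_tracker()  # moved layers/RNG live in core.tensor_parallel
```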
...@@ -10,7 +10,6 @@ from .global_vars import get_tokenizer ...@@ -10,7 +10,6 @@ from .global_vars import get_tokenizer
from .global_vars import get_tensorboard_writer from .global_vars import get_tensorboard_writer
from .global_vars import get_adlr_autoresume from .global_vars import get_adlr_autoresume
from .global_vars import get_timers from .global_vars import get_timers
from .global_vars import get_global_memory_buffer
from .initialize import initialize_megatron from .initialize import initialize_megatron
from .utils import (print_rank_0, from .utils import (print_rank_0,
......
...@@ -9,7 +9,7 @@ import numpy as np ...@@ -9,7 +9,7 @@ import numpy as np
import torch import torch
from megatron import (mpu, from megatron import (core,
update_num_microbatches) update_num_microbatches)
from .global_vars import get_args from .global_vars import get_args
from .utils import (unwrap_model, from .utils import (unwrap_model,
...@@ -79,11 +79,11 @@ def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer, ...@@ -79,11 +79,11 @@ def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
# Use both the tensor and pipeline MP rank. # Use both the tensor and pipeline MP rank.
if pipeline_parallel is None: if pipeline_parallel is None:
pipeline_parallel = (mpu.get_pipeline_model_parallel_world_size() > 1) pipeline_parallel = (core.get_pipeline_model_parallel_world_size() > 1)
if tensor_rank is None: if tensor_rank is None:
tensor_rank = mpu.get_tensor_model_parallel_rank() tensor_rank = core.get_tensor_model_parallel_rank()
if pipeline_rank is None: if pipeline_rank is None:
pipeline_rank = mpu.get_pipeline_model_parallel_rank() pipeline_rank = core.get_pipeline_model_parallel_rank()
# Use both the tensor and pipeline MP rank. If using the distributed # Use both the tensor and pipeline MP rank. If using the distributed
# optimizer, then the optimizer's path must additionally include the # optimizer, then the optimizer's path must additionally include the
...@@ -98,7 +98,7 @@ def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer, ...@@ -98,7 +98,7 @@ def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
if use_distributed_optimizer: if use_distributed_optimizer:
model_name = os.path.join(common_path, "model_rng.pt") model_name = os.path.join(common_path, "model_rng.pt")
optim_name = os.path.join( optim_name = os.path.join(
common_path + "_%03d" % mpu.get_data_parallel_rank(), common_path + "_%03d" % core.get_data_parallel_rank(),
"optim.pt") "optim.pt")
else: else:
model_name = optim_name = os.path.join(common_path, "model_optim_rng.pt") model_name = optim_name = os.path.join(common_path, "model_optim_rng.pt")
...@@ -185,18 +185,18 @@ def get_rng_state(): ...@@ -185,18 +185,18 @@ def get_rng_state():
'np_rng_state': np.random.get_state(), 'np_rng_state': np.random.get_state(),
'torch_rng_state': torch.get_rng_state(), 'torch_rng_state': torch.get_rng_state(),
'cuda_rng_state': torch.cuda.get_rng_state(), 'cuda_rng_state': torch.cuda.get_rng_state(),
'rng_tracker_states': mpu.get_cuda_rng_tracker().get_states()} 'rng_tracker_states': core.tensor_parallel.get_cuda_rng_tracker().get_states()}
rng_state_list = None rng_state_list = None
if torch.distributed.is_initialized() and \ if torch.distributed.is_initialized() and \
mpu.get_data_parallel_world_size() > 1 and \ core.get_data_parallel_world_size() > 1 and \
args.data_parallel_random_init: args.data_parallel_random_init:
rng_state_list = \ rng_state_list = \
[None for i in range(mpu.get_data_parallel_world_size())] [None for i in range(core.get_data_parallel_world_size())]
torch.distributed.all_gather_object( torch.distributed.all_gather_object(
rng_state_list, rng_state_list,
rng_state, rng_state,
group=mpu.get_data_parallel_group()) group=core.get_data_parallel_group())
else: else:
rng_state_list = [rng_state] rng_state_list = [rng_state]
...@@ -223,7 +223,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler): ...@@ -223,7 +223,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
# Collect args, model, RNG. # Collect args, model, RNG.
model_state_dict = {} model_state_dict = {}
if not torch.distributed.is_initialized() \ if not torch.distributed.is_initialized() \
or mpu.get_data_parallel_rank() == 0: or core.get_data_parallel_rank() == 0:
# Arguments, iteration, and model. # Arguments, iteration, and model.
model_state_dict['args'] = args model_state_dict['args'] = args
...@@ -233,7 +233,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler): ...@@ -233,7 +233,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
model_state_dict['model'] = model[0].state_dict_for_save_checkpoint() model_state_dict['model'] = model[0].state_dict_for_save_checkpoint()
else: else:
for i in range(len(model)): for i in range(len(model)):
mpu.set_virtual_pipeline_model_parallel_rank(i) core.set_virtual_pipeline_model_parallel_rank(i)
model_state_dict['model%d' % i] = \ model_state_dict['model%d' % i] = \
model[i].state_dict_for_save_checkpoint() model[i].state_dict_for_save_checkpoint()
...@@ -246,7 +246,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler): ...@@ -246,7 +246,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
optim_state_dict = {} optim_state_dict = {}
if not args.no_save_optim \ if not args.no_save_optim \
and (not torch.distributed.is_initialized() and (not torch.distributed.is_initialized()
or mpu.get_data_parallel_rank() == 0 or core.get_data_parallel_rank() == 0
or args.use_distributed_optimizer): or args.use_distributed_optimizer):
# Optimizer stuff. # Optimizer stuff.
...@@ -548,7 +548,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri ...@@ -548,7 +548,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
model[0].load_state_dict(model_state_dict['model'], strict=strict) model[0].load_state_dict(model_state_dict['model'], strict=strict)
else: else:
for i in range(len(model)): for i in range(len(model)):
mpu.set_virtual_pipeline_model_parallel_rank(i) core.set_virtual_pipeline_model_parallel_rank(i)
model[i].load_state_dict(model_state_dict['model%d' % i], strict=strict) model[i].load_state_dict(model_state_dict['model%d' % i], strict=strict)
# Fix up query/key/value matrix ordering if needed # Fix up query/key/value matrix ordering if needed
...@@ -580,7 +580,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri ...@@ -580,7 +580,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
# access rng_state for data parallel rank # access rng_state for data parallel rank
if args.data_parallel_random_init: if args.data_parallel_random_init:
rng_state = model_state_dict['rng_state'][mpu.get_data_parallel_rank()] rng_state = model_state_dict['rng_state'][core.get_data_parallel_rank()]
else: else:
rng_state = model_state_dict['rng_state'][0] rng_state = model_state_dict['rng_state'][0]
random.setstate(rng_state['random_rng_state']) random.setstate(rng_state['random_rng_state'])
...@@ -590,7 +590,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri ...@@ -590,7 +590,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
# Check for empty states array # Check for empty states array
if not rng_state['rng_tracker_states']: if not rng_state['rng_tracker_states']:
raise KeyError raise KeyError
mpu.get_cuda_rng_tracker().set_states( core.tensor_parallel.get_cuda_rng_tracker().set_states(
rng_state['rng_tracker_states']) rng_state['rng_tracker_states'])
else: # backward compatability else: # backward compatability
random.setstate(model_state_dict['random_rng_state']) random.setstate(model_state_dict['random_rng_state'])
...@@ -600,7 +600,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri ...@@ -600,7 +600,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
# Check for empty states array # Check for empty states array
if not model_state_dict['rng_tracker_states']: if not model_state_dict['rng_tracker_states']:
raise KeyError raise KeyError
mpu.get_cuda_rng_tracker().set_states( core.tensor_parallel.get_cuda_rng_tracker().set_states(
model_state_dict['rng_tracker_states']) model_state_dict['rng_tracker_states'])
except KeyError: except KeyError:
print_rank_0('Unable to load rng state from checkpoint {}. ' print_rank_0('Unable to load rng state from checkpoint {}. '
...@@ -640,7 +640,7 @@ def load_biencoder_checkpoint(model, only_query_model=False, ...@@ -640,7 +640,7 @@ def load_biencoder_checkpoint(model, only_query_model=False,
args.use_distributed_optimizer, args.use_distributed_optimizer,
release=False) release=False)
if mpu.get_data_parallel_rank() == 0: if core.get_data_parallel_rank() == 0:
print('global rank {} is loading checkpoint {}'.format( print('global rank {} is loading checkpoint {}'.format(
torch.distributed.get_rank(), checkpoint_name)) torch.distributed.get_rank(), checkpoint_name))
...@@ -656,7 +656,7 @@ def load_biencoder_checkpoint(model, only_query_model=False, ...@@ -656,7 +656,7 @@ def load_biencoder_checkpoint(model, only_query_model=False,
model[0].load_state_dict(ret_state_dict) model[0].load_state_dict(ret_state_dict)
torch.distributed.barrier() torch.distributed.barrier()
if mpu.get_data_parallel_rank() == 0: if core.get_data_parallel_rank() == 0:
print(' successfully loaded {}'.format(checkpoint_name)) print(' successfully loaded {}'.format(checkpoint_name))
return model return model
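The checkpointing changes above only swap namespaces; the RNG-collection logic itself is unchanged. A condensed sketch of the new pattern (illustrative, assuming `torch.distributed` and Megatron parallel state are already initialized):

```python
import random
import numpy as np
import torch
from megatron import core

rng_state = {
    'random_rng_state': random.getstate(),
    'np_rng_state': np.random.get_state(),
    'torch_rng_state': torch.get_rng_state(),
    'cuda_rng_state': torch.cuda.get_rng_state(),
    # previously mpu.get_cuda_rng_tracker(); now under core.tensor_parallel
    'rng_tracker_states': core.tensor_parallel.get_cuda_rng_tracker().get_states(),
}

# Gathering per-rank RNG states across the data-parallel group also goes through megatron.core:
rng_state_list = [None] * core.get_data_parallel_world_size()
torch.distributed.all_gather_object(
    rng_state_list, rng_state, group=core.get_data_parallel_group())
```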
from .parallel_state import ( from .parallel_state import (
initialize_model_parallel, initialize_model_parallel,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
get_tensor_model_parallel_rank,
get_pipeline_model_parallel_world_size, get_pipeline_model_parallel_world_size,
get_pipeline_model_parallel_rank,
get_virtual_pipeline_model_parallel_rank, set_virtual_pipeline_model_parallel_rank,
get_data_parallel_world_size, get_data_parallel_world_size,
get_data_parallel_rank,
get_global_memory_buffer,
get_num_layers,
) )
from megatron.core import tensor_parallel from megatron.core import tensor_parallel
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
import torch import torch
from typing import Optional from typing import Optional
from .utils import GlobalMemoryBuffer
# Intra-layer model parallel group that the current rank belongs to. # Intra-layer model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None _TENSOR_MODEL_PARALLEL_GROUP = None
# Inter-layer model parallel group that the current rank belongs to. # Inter-layer model parallel group that the current rank belongs to.
...@@ -42,7 +44,8 @@ _PIPELINE_GLOBAL_RANKS = None ...@@ -42,7 +44,8 @@ _PIPELINE_GLOBAL_RANKS = None
# rank when broadcasting weights from src to all other data parallel ranks # rank when broadcasting weights from src to all other data parallel ranks
_DATA_PARALLEL_GLOBAL_RANKS = None _DATA_PARALLEL_GLOBAL_RANKS = None
# Memory buffers to avoid dynamic memory allocation
_GLOBAL_MEMORY_BUFFER = None
def is_unitialized(): def is_unitialized():
"""Useful for code segments that may be accessed with or without mpu initialization""" """Useful for code segments that may be accessed with or without mpu initialization"""
...@@ -195,6 +198,12 @@ def initialize_model_parallel( ...@@ -195,6 +198,12 @@ def initialize_model_parallel(
if rank in ranks: if rank in ranks:
_POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks
# Initialize global memory buffer
# This isn't really "parallel state" but there isn't another good place to
# put this. If we end up with a more generic initialization of megatron-core
# we could stick it there
_set_global_memory_buffer()
def model_parallel_is_initialized(): def model_parallel_is_initialized():
"""Check if model and data parallel groups are initialized.""" """Check if model and data parallel groups are initialized."""
...@@ -506,6 +515,18 @@ def get_data_parallel_rank(): ...@@ -506,6 +515,18 @@ def get_data_parallel_rank():
"""Return my rank for the data parallel group.""" """Return my rank for the data parallel group."""
return torch.distributed.get_rank(group=get_data_parallel_group()) return torch.distributed.get_rank(group=get_data_parallel_group())
def _set_global_memory_buffer():
"""Initialize global buffer"""
global _GLOBAL_MEMORY_BUFFER
assert _GLOBAL_MEMORY_BUFFER is None, 'global memory buffer is already initialized'
_GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer()
def get_global_memory_buffer():
assert _GLOBAL_MEMORY_BUFFER is not None, 'global memory buffer is not initialized'
return _GLOBAL_MEMORY_BUFFER
def destroy_model_parallel(): def destroy_model_parallel():
"""Set the groups to none.""" """Set the groups to none."""
global _MODEL_PARALLEL_GROUP global _MODEL_PARALLEL_GROUP
......
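The global memory buffer previously owned by `megatron.global_vars` now lives in parallel state and is created inside `initialize_model_parallel()`. A minimal usage sketch (buffer name and shape are illustrative):

```python
import torch
from megatron.core import get_global_memory_buffer

# Returns a view over a lazily grown, reusable buffer keyed by (name, dtype).
scores = get_global_memory_buffer().get_tensor([8, 1024], torch.float16, "attn_scores")
```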
from .cross_entropy import vocab_parallel_cross_entropy from .cross_entropy import vocab_parallel_cross_entropy
from .data import broadcast_data from .data import broadcast_data
from .layers import (
ColumnParallelLinear,
RowParallelLinear,
VocabParallelEmbedding,
set_defaults_if_not_set_tensor_model_parallel_attributes,
copy_tensor_model_parallel_attributes,
param_is_not_tensor_parallel_duplicate,
linear_with_grad_accumulation_and_async_allreduce
)
from .mappings import (
copy_to_tensor_model_parallel_region,
gather_from_tensor_model_parallel_region,
gather_from_sequence_parallel_region,
scatter_to_tensor_model_parallel_region,
scatter_to_sequence_parallel_region,
)
from .random import (
checkpoint,
get_cuda_rng_tracker,
model_parallel_cuda_manual_seed
)
from .utils import split_tensor_along_last_dim
__all__ = [ __all__ = [
# cross_entropy.py # cross_entropy.py
"vocab_parallel_cross_entropy", "vocab_parallel_cross_entropy",
# data.py # data.py
"broadcast_data", "broadcast_data",
#layers.py
"ColumnParallelLinear",
"RowParallelLinear",
"VocabParallelEmbedding",
"set_defaults_if_not_set_tensor_model_parallel_attributes",
"copy_tensor_model_parallel_attributes",
"param_is_not_tensor_parallel_duplicate",
"linear_with_grad_accumulation_and_async_allreduce",
# mappings.py
"copy_to_tensor_model_parallel_region",
"gather_from_tensor_model_parallel_region",
"gather_from_sequence_parallel_region",
# "reduce_from_tensor_model_parallel_region",
"scatter_to_tensor_model_parallel_region",
"scatter_to_sequence_parallel_region",
# random.py
"checkpoint",
"get_cuda_rng_tracker",
"model_parallel_cuda_manual_seed",
# utils.py
"split_tensor_along_last_dim",
] ]
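With the re-exports above, callers can import the moved layers and helpers directly from the new package; the snippet below only performs imports, with names taken from the `__all__` list:

```python
from megatron.core.tensor_parallel import (
    ColumnParallelLinear,
    RowParallelLinear,
    VocabParallelEmbedding,
    copy_to_tensor_model_parallel_region,
    get_cuda_rng_tracker,
    model_parallel_cuda_manual_seed,
)
```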
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Parts of the code here are adapted from PyTorch # Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch # repo: https://github.com/pytorch/pytorch
import math import math
from typing import Optional
import warnings
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import torch.nn.init as init import torch.nn.init as init
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from .initialize import get_tensor_model_parallel_rank from megatron.core.parallel_state import (
from .initialize import get_tensor_model_parallel_world_size get_tensor_model_parallel_rank,
from .initialize import get_tensor_model_parallel_group get_tensor_model_parallel_world_size,
from .mappings import copy_to_tensor_model_parallel_region get_tensor_model_parallel_group,
from .mappings import gather_from_tensor_model_parallel_region get_global_memory_buffer,
from .mappings import gather_from_sequence_parallel_region )
from .mappings import reduce_from_tensor_model_parallel_region from .mappings import (
from .mappings import scatter_to_tensor_model_parallel_region copy_to_tensor_model_parallel_region,
from .mappings import reduce_scatter_to_sequence_parallel_region gather_from_tensor_model_parallel_region,
gather_from_sequence_parallel_region,
reduce_from_tensor_model_parallel_region,
scatter_to_tensor_model_parallel_region,
reduce_scatter_to_sequence_parallel_region,
)
from .random import get_cuda_rng_tracker from .random import get_cuda_rng_tracker
from .utils import divide from .utils import (
from .utils import split_tensor_along_last_dim divide,
from .utils import VocabUtility split_tensor_along_last_dim,
from megatron import get_args, get_global_memory_buffer VocabUtility,
)
_grad_accum_fusion_available = True
try:
import fused_weight_gradient_mlp_cuda
except ImportError:
_grad_accum_fusion_available = False
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False, _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False,
'partition_dim': -1, 'partition_dim': -1,
...@@ -81,7 +93,8 @@ def _initialize_affine_weight_gpu(weight, init_method, ...@@ -81,7 +93,8 @@ def _initialize_affine_weight_gpu(weight, init_method,
def _initialize_affine_weight_cpu(weight, output_size, input_size, def _initialize_affine_weight_cpu(weight, output_size, input_size,
per_partition_size, partition_dim, per_partition_size, partition_dim,
init_method, stride=1, init_method, stride=1,
return_master_weight=False): return_master_weight=False,
*, params_dtype=torch.float32):
"""Initialize affine weight for model parallel. """Initialize affine weight for model parallel.
Build the master weight on all processes and scatter Build the master weight on all processes and scatter
...@@ -97,8 +110,7 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size, ...@@ -97,8 +110,7 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size,
dtype=torch.float, dtype=torch.float,
requires_grad=False) requires_grad=False)
init_method(master_weight) init_method(master_weight)
args = get_args() master_weight = master_weight.to(dtype=params_dtype)
master_weight = master_weight.to(dtype=args.params_dtype)
# Split and copy # Split and copy
per_partition_per_stride_size = divide(per_partition_size, stride) per_partition_per_stride_size = divide(per_partition_size, stride)
...@@ -123,11 +135,19 @@ class VocabParallelEmbedding(torch.nn.Module): ...@@ -123,11 +135,19 @@ class VocabParallelEmbedding(torch.nn.Module):
Arguments: Arguments:
num_embeddings: vocabulary size. num_embeddings: vocabulary size.
embedding_dim: size of hidden state. embedding_dim: size of hidden state.
Keyword Arguments:
init_method: method to initialize weights. init_method: method to initialize weights.
params_dtype
use_cpu_initialization
perform_initialization
""" """
def __init__(self, num_embeddings, embedding_dim, def __init__(self, num_embeddings: int, embedding_dim: int, *,
init_method=init.xavier_normal_): init_method=init.xavier_normal_,
params_dtype: torch.dtype=torch.float32,
use_cpu_initialization: bool=False,
perform_initialization: bool=True):
super(VocabParallelEmbedding, self).__init__() super(VocabParallelEmbedding, self).__init__()
# Keep the input dimensions. # Keep the input dimensions.
self.num_embeddings = num_embeddings self.num_embeddings = num_embeddings
...@@ -149,20 +169,20 @@ class VocabParallelEmbedding(torch.nn.Module): ...@@ -149,20 +169,20 @@ class VocabParallelEmbedding(torch.nn.Module):
self.vocab_start_index self.vocab_start_index
# Allocate weights and initialize. # Allocate weights and initialize.
args = get_args() if use_cpu_initialization:
if args.use_cpu_initialization:
self.weight = Parameter(torch.empty( self.weight = Parameter(torch.empty(
self.num_embeddings_per_partition, self.embedding_dim, self.num_embeddings_per_partition, self.embedding_dim,
dtype=args.params_dtype)) dtype=params_dtype))
if args.perform_initialization: if perform_initialization:
_initialize_affine_weight_cpu( _initialize_affine_weight_cpu(
self.weight, self.num_embeddings, self.embedding_dim, self.weight, self.num_embeddings, self.embedding_dim,
self.num_embeddings_per_partition, 0, init_method) self.num_embeddings_per_partition, 0, init_method,
params_dtype=params_dtype)
else: else:
self.weight = Parameter(torch.empty( self.weight = Parameter(torch.empty(
self.num_embeddings_per_partition, self.embedding_dim, self.num_embeddings_per_partition, self.embedding_dim,
device=torch.cuda.current_device(), dtype=args.params_dtype)) device=torch.cuda.current_device(), dtype=params_dtype))
if args.perform_initialization: if perform_initialization:
_initialize_affine_weight_gpu(self.weight, init_method, _initialize_affine_weight_gpu(self.weight, init_method,
partition_dim=0, stride=1) partition_dim=0, stride=1)
...@@ -203,7 +223,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function): ...@@ -203,7 +223,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
ctx.gradient_accumulation_fusion = gradient_accumulation_fusion ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
ctx.async_grad_allreduce = async_grad_allreduce ctx.async_grad_allreduce = async_grad_allreduce
ctx.sequence_parallel = sequence_parallel ctx.sequence_parallel = sequence_parallel
if sequence_parallel: if sequence_parallel:
world_size = get_tensor_model_parallel_world_size() world_size = get_tensor_model_parallel_world_size()
dim_size = list(input.size()) dim_size = list(input.size())
...@@ -228,7 +248,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function): ...@@ -228,7 +248,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
def backward(ctx, grad_output): def backward(ctx, grad_output):
input, weight = ctx.saved_tensors input, weight = ctx.saved_tensors
use_bias = ctx.use_bias use_bias = ctx.use_bias
if ctx.sequence_parallel: if ctx.sequence_parallel:
world_size = get_tensor_model_parallel_world_size() world_size = get_tensor_model_parallel_world_size()
dim_size = list(input.size()) dim_size = list(input.size())
...@@ -257,7 +277,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function): ...@@ -257,7 +277,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
grad_output.shape[2]) grad_output.shape[2])
total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input = total_input.view(total_input.shape[0] * total_input.shape[1],
total_input.shape[2]) total_input.shape[2])
if ctx.async_grad_allreduce: if ctx.async_grad_allreduce:
# Asynchronous all-reduce # Asynchronous all-reduce
handle = torch.distributed.all_reduce( handle = torch.distributed.all_reduce(
...@@ -265,7 +285,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function): ...@@ -265,7 +285,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
# Delay the start of weight gradient computation shortly (3us) to have # Delay the start of weight gradient computation shortly (3us) to have
# all-reduce scheduled first and have GPU resources allocated # all-reduce scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1 _ = torch.empty(1, device=grad_output.device) + 1
if ctx.sequence_parallel: if ctx.sequence_parallel:
assert not ctx.async_grad_allreduce assert not ctx.async_grad_allreduce
dim_size = list(input.size()) dim_size = list(input.size())
...@@ -273,17 +293,16 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function): ...@@ -273,17 +293,16 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
device=torch.cuda.current_device(), device=torch.cuda.current_device(),
requires_grad=False) requires_grad=False)
# reduce_scatter # reduce_scatter
handle = torch.distributed._reduce_scatter_base(sub_grad_input, grad_input, handle = torch.distributed._reduce_scatter_base(sub_grad_input, grad_input,
group=get_tensor_model_parallel_group(), group=get_tensor_model_parallel_group(),
async_op=True) async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have # Delay the start of weight gradient computation shortly (3us) to have
# reduce scatter scheduled first and have GPU resources allocated # reduce scatter scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1 _ = torch.empty(1, device=grad_output.device) + 1
if ctx.gradient_accumulation_fusion: if ctx.gradient_accumulation_fusion:
import fused_dense_cuda fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, weight.main_grad)
fused_dense_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, weight.main_grad)
grad_weight = None grad_weight = None
else: else:
grad_weight = grad_output.t().matmul(total_input) grad_weight = grad_output.t().matmul(total_input)
...@@ -298,6 +317,25 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function): ...@@ -298,6 +317,25 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
return grad_input, grad_weight, grad_bias, None, None, None return grad_input, grad_weight, grad_bias, None, None, None
def linear_with_grad_accumulation_and_async_allreduce(
input: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor],
gradient_accumulation_fusion: bool,
async_grad_allreduce: bool,
sequence_parallel_enabled: bool,
) -> torch.Tensor:
args = [
input,
weight,
bias,
gradient_accumulation_fusion,
async_grad_allreduce,
sequence_parallel_enabled,
]
with torch.cuda.amp.autocast(enabled=False):
return LinearWithGradAccumulationAndAsyncCommunication.apply(*args)
class ColumnParallelLinear(torch.nn.Module): class ColumnParallelLinear(torch.nn.Module):
"""Linear layer with column parallelism. """Linear layer with column parallelism.
...@@ -308,6 +346,8 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -308,6 +346,8 @@ class ColumnParallelLinear(torch.nn.Module):
Arguments: Arguments:
input_size: first dimension of matrix A. input_size: first dimension of matrix A.
output_size: second dimension of matrix A. output_size: second dimension of matrix A.
Keyword Arguments:
bias: If true, add bias bias: If true, add bias
gather_output: If true, call all-gather on output and make Y available gather_output: If true, call all-gather on output and make Y available
to all GPUs, otherwise, every GPU will have its output to all GPUs, otherwise, every GPU will have its output
...@@ -321,12 +361,25 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -321,12 +361,25 @@ class ColumnParallelLinear(torch.nn.Module):
skip_bias_add: This was added to enable performance optimations where bias skip_bias_add: This was added to enable performance optimations where bias
can be fused with other elementwise operations. we skip can be fused with other elementwise operations. we skip
adding bias but instead return it. adding bias but instead return it.
async_tensor_model_parallel_allreduce:
params_dtype:
use_cpu_initialization:
gradient_accumulation_fusion:
sequence_parallel_enabled:
""" """
def __init__(self, input_size, output_size, bias=True, gather_output=True, def __init__(self, input_size, output_size, *,
bias=True, gather_output=True,
init_method=init.xavier_normal_, stride=1, init_method=init.xavier_normal_, stride=1,
keep_master_weight_for_test=False, keep_master_weight_for_test=False,
skip_bias_add=False): skip_bias_add=False,
async_tensor_model_parallel_allreduce=True,
params_dtype=torch.float32,
use_cpu_initialization=False,
perform_initialization=True,
gradient_accumulation_fusion=False,
sequence_parallel_enabled: bool = False,
):
super(ColumnParallelLinear, self).__init__() super(ColumnParallelLinear, self).__init__()
# Keep input parameters # Keep input parameters
...@@ -342,12 +395,11 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -342,12 +395,11 @@ class ColumnParallelLinear(torch.nn.Module):
# Note: torch.nn.functional.linear performs XA^T + b and as a result # Note: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose. # we allocate the transpose.
# Initialize weight. # Initialize weight.
args = get_args() if use_cpu_initialization:
if args.use_cpu_initialization:
self.weight = Parameter(torch.empty(self.output_size_per_partition, self.weight = Parameter(torch.empty(self.output_size_per_partition,
self.input_size, self.input_size,
dtype=args.params_dtype)) dtype=params_dtype))
if args.perform_initialization: if perform_initialization:
self.master_weight = _initialize_affine_weight_cpu( self.master_weight = _initialize_affine_weight_cpu(
self.weight, self.output_size, self.input_size, self.weight, self.output_size, self.input_size,
self.output_size_per_partition, 0, init_method, self.output_size_per_partition, 0, init_method,
...@@ -355,51 +407,87 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -355,51 +407,87 @@ class ColumnParallelLinear(torch.nn.Module):
else: else:
self.weight = Parameter(torch.empty( self.weight = Parameter(torch.empty(
self.output_size_per_partition, self.input_size, self.output_size_per_partition, self.input_size,
device=torch.cuda.current_device(), dtype=args.params_dtype)) device=torch.cuda.current_device(), dtype=params_dtype))
if args.perform_initialization: if perform_initialization:
_initialize_affine_weight_gpu(self.weight, init_method, _initialize_affine_weight_gpu(self.weight, init_method,
partition_dim=0, stride=stride) partition_dim=0, stride=stride)
if bias: if bias:
if args.use_cpu_initialization: if use_cpu_initialization:
self.bias = Parameter(torch.empty( self.bias = Parameter(torch.empty(
self.output_size_per_partition, dtype=args.params_dtype)) self.output_size_per_partition, dtype=params_dtype))
else: else:
self.bias = Parameter(torch.empty( self.bias = Parameter(torch.empty(
self.output_size_per_partition, self.output_size_per_partition,
device=torch.cuda.current_device(), device=torch.cuda.current_device(),
dtype=args.params_dtype)) dtype=params_dtype))
set_tensor_model_parallel_attributes(self.bias, True, 0, stride) set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
# Always initialize bias to zero. # Always initialize bias to zero.
with torch.no_grad(): with torch.no_grad():
self.bias.zero_() self.bias.zero_()
else: else:
self.register_parameter('bias', None) self.register_parameter('bias', None)
self.async_tensor_model_parallel_allreduce = ( self.async_tensor_model_parallel_allreduce = (
args.async_tensor_model_parallel_allreduce and async_tensor_model_parallel_allreduce and
world_size > 1)
self.sequence_parallel = (
args.sequence_parallel and
world_size > 1) world_size > 1)
assert not self.async_tensor_model_parallel_allreduce or \ if sequence_parallel_enabled:
not self.sequence_parallel if world_size <= 1:
self.gradient_accumulation_fusion = args.gradient_accumulation_fusion warnings.warn(
f"`sequence_parallel_enabled` is set to `True`, but tensor model parallel size is {world_size}. "
f"Disabling sequence parallel."
)
sequence_parallel_enabled = False
self.sequence_parallel_enabled = sequence_parallel_enabled
if gradient_accumulation_fusion:
if not _grad_accum_fusion_available:
# megatron.core users are expected to build APEX with its `--cpp_ext`
# and `--cuda_ext` extensions. The example installation command is:
# `pip install --global-option="--cpp_ext" --global-option="--cuda_ext" .`
# at the root of the APEX repository.
warnings.warn(
"`gradient_accumulation_fusion` is set to `True` but "
"the custom CUDA extension of `fused_weight_gradient_mlp_cuda` module not "
"found. Thus `gradient_accumulation_fusion` set to `False`. "
"Note that the extension requires CUDA>=11."
)
gradient_accumulation_fusion = False
self.gradient_accumulation_fusion = gradient_accumulation_fusion
if self.async_tensor_model_parallel_allreduce and self.sequence_parallel_enabled:
raise RuntimeError("`async_tensor_model_parallel_allreduce` and `sequence_parallel_enabled` cannot be enabled at the same time.")
def forward(self, input_): def forward(self, input_):
"""Forward of ColumnParallelLinear
Args:
input_: 3D tensor whose order of dimension is [sequence, batch, hidden]
Returns:
- output
- bias
"""
bias = self.bias if not self.skip_bias_add else None bias = self.bias if not self.skip_bias_add else None
if self.async_tensor_model_parallel_allreduce or \ if self.async_tensor_model_parallel_allreduce or \
self.sequence_parallel: self.sequence_parallel_enabled:
input_parallel = input_ input_parallel = input_
else: else:
input_parallel = copy_to_tensor_model_parallel_region(input_) input_parallel = copy_to_tensor_model_parallel_region(input_)
# Matrix multiply. # Matrix multiply.
output_parallel = LinearWithGradAccumulationAndAsyncCommunication.apply( output_parallel = linear_with_grad_accumulation_and_async_allreduce(
input_parallel, self.weight, bias, self.gradient_accumulation_fusion, input=input_parallel,
self.async_tensor_model_parallel_allreduce, self.sequence_parallel) weight=self.weight,
bias=bias,
gradient_accumulation_fusion=self.gradient_accumulation_fusion,
async_grad_allreduce=self.async_tensor_model_parallel_allreduce,
sequence_parallel_enabled=self.sequence_parallel_enabled,
)
if self.gather_output: if self.gather_output:
# All-gather across the partitions. # All-gather across the partitions.
assert not self.sequence_parallel assert not self.sequence_parallel_enabled
output = gather_from_tensor_model_parallel_region(output_parallel) output = gather_from_tensor_model_parallel_region(output_parallel)
else: else:
output = output_parallel output = output_parallel
...@@ -422,6 +510,8 @@ class RowParallelLinear(torch.nn.Module): ...@@ -422,6 +510,8 @@ class RowParallelLinear(torch.nn.Module):
Arguments: Arguments:
input_size: first dimension of matrix A. input_size: first dimension of matrix A.
output_size: second dimension of matrix A. output_size: second dimension of matrix A.
Keyword Arguments:
bias: If true, add bias. Note that bias is not parallelized. bias: If true, add bias. Note that bias is not parallelized.
input_is_parallel: If true, we assume that the input is already input_is_parallel: If true, we assume that the input is already
split across the GPUs and we do not split split across the GPUs and we do not split
...@@ -435,13 +525,24 @@ class RowParallelLinear(torch.nn.Module): ...@@ -435,13 +525,24 @@ class RowParallelLinear(torch.nn.Module):
skip_bias_add: This was added to enable performance optimization where bias skip_bias_add: This was added to enable performance optimization where bias
can be fused with other elementwise operations. We skip can be fused with other elementwise operations. We skip
adding bias but instead return it. adding bias but instead return it.
params_dtype:
use_cpu_initialization:
perform_initialization:
gradient_accumulation_fusion:
sequence_parallel_enabled:
""" """
def __init__(self, input_size, output_size, bias=True, def __init__(self, input_size, output_size, *,
input_is_parallel=False, bias=True, input_is_parallel=False,
init_method=init.xavier_normal_, stride=1, init_method=init.xavier_normal_, stride=1,
keep_master_weight_for_test=False, keep_master_weight_for_test=False,
skip_bias_add=False): skip_bias_add=False,
params_dtype=torch.float32,
use_cpu_initialization=False,
perform_initialization=True,
gradient_accumulation_fusion=False,
sequence_parallel_enabled: bool = False,
):
super(RowParallelLinear, self).__init__() super(RowParallelLinear, self).__init__()
# Keep input parameters # Keep input parameters
...@@ -452,61 +553,78 @@ class RowParallelLinear(torch.nn.Module): ...@@ -452,61 +553,78 @@ class RowParallelLinear(torch.nn.Module):
world_size = get_tensor_model_parallel_world_size() world_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, world_size) self.input_size_per_partition = divide(input_size, world_size)
self.skip_bias_add = skip_bias_add self.skip_bias_add = skip_bias_add
self.gradient_accumulation_fusion = gradient_accumulation_fusion
self.sequence_parallel_enabled = sequence_parallel_enabled
if self.sequence_parallel_enabled and not self.input_is_parallel:
raise RuntimeError("To enable `sequence_parallel_enabled`, `input_is_parallel` must be `True`")
# Parameters. # Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result # Note: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose. # we allocate the transpose.
# Initialize weight. # Initialize weight.
args = get_args() if use_cpu_initialization:
if args.use_cpu_initialization:
self.weight = Parameter(torch.empty(self.output_size, self.weight = Parameter(torch.empty(self.output_size,
self.input_size_per_partition, self.input_size_per_partition,
dtype=args.params_dtype)) dtype=params_dtype))
if args.perform_initialization: if perform_initialization:
self.master_weight = _initialize_affine_weight_cpu( self.master_weight = _initialize_affine_weight_cpu(
self.weight, self.output_size, self.input_size, self.weight, self.output_size, self.input_size,
self.input_size_per_partition, 1, init_method, self.input_size_per_partition, 1, init_method,
stride=stride, return_master_weight=keep_master_weight_for_test) stride=stride, return_master_weight=keep_master_weight_for_test,
params_dtype=params_dtype)
else: else:
self.weight = Parameter(torch.empty( self.weight = Parameter(torch.empty(
self.output_size, self.input_size_per_partition, self.output_size, self.input_size_per_partition,
device=torch.cuda.current_device(), dtype=args.params_dtype)) device=torch.cuda.current_device(), dtype=params_dtype))
if args.perform_initialization: if perform_initialization:
_initialize_affine_weight_gpu(self.weight, init_method, _initialize_affine_weight_gpu(self.weight, init_method,
partition_dim=1, stride=stride) partition_dim=1, stride=stride)
if bias: if bias:
if args.use_cpu_initialization: if use_cpu_initialization:
self.bias = Parameter(torch.empty(self.output_size, self.bias = Parameter(torch.empty(self.output_size,
dtype=args.params_dtype)) dtype=params_dtype))
else: else:
self.bias = Parameter(torch.empty( self.bias = Parameter(torch.empty(
self.output_size, device=torch.cuda.current_device(), self.output_size, device=torch.cuda.current_device(),
dtype=args.params_dtype)) dtype=params_dtype))
setattr(self.bias, 'sequence_parallel', args.sequence_parallel) setattr(self.bias, 'sequence_parallel', sequence_parallel_enabled)
# Always initialize bias to zero. # Always initialize bias to zero.
with torch.no_grad(): with torch.no_grad():
self.bias.zero_() self.bias.zero_()
else: else:
self.register_parameter('bias', None) self.register_parameter('bias', None)
self.sequence_parallel = args.sequence_parallel
self.gradient_accumulation_fusion = args.gradient_accumulation_fusion
def forward(self, input_): def forward(self, input_):
"""Forward of RowParallelLinear
Args:
input_: 3D tensor whose order of dimension is [sequence, batch, hidden]
Returns:
- output
- bias
"""
# Set up backprop all-reduce. # Set up backprop all-reduce.
if self.input_is_parallel: if self.input_is_parallel:
input_parallel = input_ input_parallel = input_
else: else:
assert not self.sequence_parallel assert not self.sequence_parallel_enabled
input_parallel = scatter_to_tensor_model_parallel_region(input_) input_parallel = scatter_to_tensor_model_parallel_region(input_)
# Matrix multiply. # Matrix multiply.
output_parallel = LinearWithGradAccumulationAndAsyncCommunication.apply( output_parallel = linear_with_grad_accumulation_and_async_allreduce(
input_parallel, self.weight, None, input=input_parallel,
self.gradient_accumulation_fusion, None, None) weight=self.weight,
bias=None,
gradient_accumulation_fusion=self.gradient_accumulation_fusion,
async_grad_allreduce=False,
sequence_parallel_enabled=False,
)
# All-reduce across all the partitions. # All-reduce across all the partitions.
if self.sequence_parallel: if self.sequence_parallel_enabled:
output_ = reduce_scatter_to_sequence_parallel_region(output_parallel) output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
else: else:
output_ = reduce_from_tensor_model_parallel_region(output_parallel) output_ = reduce_from_tensor_model_parallel_region(output_parallel)
......
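Because the layers no longer read `get_args()`, configuration that used to come from global arguments (`params_dtype`, `use_cpu_initialization`, `perform_initialization`, `gradient_accumulation_fusion`, `sequence_parallel_enabled`, ...) is now passed explicitly. A hedged construction sketch, with placeholder sizes and assuming tensor model parallelism has already been initialized:

```python
import torch
from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear

# Illustrative MLP-style pair; 1024/4096 are placeholder sizes, and
# megatron.core.initialize_model_parallel() must have been called first.
dense_h_to_4h = ColumnParallelLinear(
    1024, 4096,
    gather_output=False,
    params_dtype=torch.bfloat16,
    use_cpu_initialization=False,
    perform_initialization=True,
    gradient_accumulation_fusion=False,
    sequence_parallel_enabled=False,
)
dense_4h_to_h = RowParallelLinear(
    4096, 1024,
    input_is_parallel=True,
    params_dtype=torch.bfloat16,
    sequence_parallel_enabled=False,
)
```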
...@@ -2,7 +2,11 @@ ...@@ -2,7 +2,11 @@
import torch import torch
from .initialize import get_tensor_model_parallel_group, get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank from megatron.core.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_tensor_model_parallel_group,
)
from .utils import split_tensor_along_last_dim from .utils import split_tensor_along_last_dim
......
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Parts of the code here are adapted from PyTorch # Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch # repo: https://github.com/pytorch/pytorch
...@@ -11,12 +10,12 @@ from torch import _C ...@@ -11,12 +10,12 @@ from torch import _C
from torch.cuda import _lazy_call, device as device_ctx_manager from torch.cuda import _lazy_call, device as device_ctx_manager
from torch.utils.checkpoint import detach_variable from torch.utils.checkpoint import detach_variable
from megatron.memory import allocate_mem_buff from megatron.core.parallel_state import (
get_data_parallel_rank,
from .initialize import get_data_parallel_rank get_tensor_model_parallel_group,
from .initialize import get_tensor_model_parallel_group get_tensor_model_parallel_rank,
from .initialize import get_tensor_model_parallel_rank get_tensor_model_parallel_world_size,
from .initialize import get_tensor_model_parallel_world_size )
# Default name for the model parallel rng tracker. # Default name for the model parallel rng tracker.
...@@ -89,85 +88,6 @@ def gather_split_1d_tensor(tensor): ...@@ -89,85 +88,6 @@ def gather_split_1d_tensor(tensor):
return gathered return gathered
def _kernel_make_viewless_tensor(inp, requires_grad):
'''Make a viewless tensor.
View tensors have the undesirable side-affect of retaining a reference
to the originally-viewed tensor, even after manually setting the '.data'
field. This method creates a new tensor that links to the old tensor's
data, without linking the viewed tensor, referenced via the '._base'
field.
'''
out = torch.empty(
(1,),
dtype = inp.dtype,
device = inp.device,
requires_grad = requires_grad,
)
out.data = inp.data
return out
class MakeViewlessTensor(torch.autograd.Function):
'''
Autograd function to make a viewless tensor.
This function should be used in cases where the computation graph needs
to be propagated, but we only want a viewless tensor (e.g.,
ParallelTransformer's hidden_states). Call this function by passing
'keep_graph = True' to 'make_viewless_tensor()'.
'''
@staticmethod
def forward(ctx, inp, requires_grad):
return _kernel_make_viewless_tensor(inp, requires_grad)
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
def make_viewless_tensor(inp, requires_grad, keep_graph):
'''
Entry-point for creating viewless tensors.
This method should be used, rather than calling 'MakeViewlessTensor'
or '_kernel_make_viewless_tensor' directly. This method acts as a
switch for determining if an autograd function or a regular method
should be used to create the tensor.
'''
# return tensor as-is, if not a 'view'
if inp._base is None:
return inp
# create viewless tensor
if keep_graph:
return MakeViewlessTensor.apply(inp, requires_grad)
else:
return _kernel_make_viewless_tensor(inp, requires_grad)
def assert_viewless_tensor(tensor, extra_msg = None):
'''Assert that a tensor is not a view (i.e., its '._base' field is
not set).'''
if isinstance(tensor, list):
[ assert_viewless_tensor(t) for t in tensor ]
return tensor
if not isinstance(tensor, torch.Tensor):
return tensor
assert tensor._base is None, (
"Ensure tensor._base is None before setting tensor.data or storing "
"tensor to memory buffer. Otherwise, a memory leak will occur (and "
"likely accumulate over iterations). %s"
) % extra_msg
return tensor
def safely_set_viewless_tensor_data(tensor, new_data_tensor):
'''Safely set tensor's '.data' field.
Check first that the tensor is viewless (i.e., '._base' not set). If not,
raise an exception.
'''
assert_viewless_tensor(tensor, extra_msg = "FYI, tensor._base has shape %s, and new_data_tensor has shape %s." % ("--" if tensor._base is None else tensor._base.shape, new_data_tensor.shape))
tensor.data = new_data_tensor
class CudaRNGStatesTracker: class CudaRNGStatesTracker:
"""Tracker for the cuda RNG states. """Tracker for the cuda RNG states.
......
...@@ -5,19 +5,6 @@ from typing import List, Sequence ...@@ -5,19 +5,6 @@ from typing import List, Sequence
from megatron.core.utils import divide from megatron.core.utils import divide
def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
assert numerator % denominator == 0, '{} is not divisible by {}'.format(
numerator, denominator)
def divide(numerator, denominator):
"""Ensure that numerator is divisible by the denominator and return
the division value."""
ensure_divisibility(numerator, denominator)
return numerator // denominator
def split_tensor_along_last_dim( def split_tensor_along_last_dim(
tensor: torch.Tensor, tensor: torch.Tensor,
num_partitions: int, num_partitions: int,
......
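`ensure_divisibility`/`divide` are removed here in favor of the copies in `megatron.core.utils`, while `split_tensor_along_last_dim` stays. A small sketch of the resulting imports:

```python
import torch
from megatron.core.utils import divide
from megatron.core.tensor_parallel.utils import split_tensor_along_last_dim

num_partitions = divide(12, 4)  # 3; asserts that 12 is evenly divisible by 4
chunks = split_tensor_along_last_dim(torch.randn(2, 12), num_partitions)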
"""Utility functions used through Megatron core""" """Utility functions used throughout Megatron core"""
from functools import reduce
import operator
import torch import torch
from megatron.core import parallel_state from megatron.core import parallel_state
...@@ -46,3 +49,101 @@ def gather_split_1d_tensor(tensor): ...@@ -46,3 +49,101 @@ def gather_split_1d_tensor(tensor):
group=parallel_state.get_tensor_model_parallel_group() group=parallel_state.get_tensor_model_parallel_group()
) )
return gathered return gathered
class GlobalMemoryBuffer:
"""Global buffer to avoid dynamic memory allocations.
Caller should ensure that buffers of the same name
are not used concurrently."""
def __init__(self):
self.buffer = {}
def get_tensor(self, tensor_shape, dtype, name):
required_len = reduce(operator.mul, tensor_shape, 1)
if self.buffer.get((name, dtype), None) is None or \
self.buffer[(name, dtype)].numel() < required_len:
self.buffer[(name, dtype)] = \
torch.empty(required_len,
dtype=dtype,
device=torch.cuda.current_device(),
requires_grad=False)
return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape)
def _kernel_make_viewless_tensor(inp, requires_grad):
'''Make a viewless tensor.
View tensors have the undesirable side-affect of retaining a reference
to the originally-viewed tensor, even after manually setting the '.data'
field. This method creates a new tensor that links to the old tensor's
data, without linking the viewed tensor, referenced via the '._base'
field.
'''
out = torch.empty(
(1,),
dtype = inp.dtype,
device = inp.device,
requires_grad = requires_grad,
)
out.data = inp.data
return out
class MakeViewlessTensor(torch.autograd.Function):
'''
Autograd function to make a viewless tensor.
This function should be used in cases where the computation graph needs
to be propagated, but we only want a viewless tensor (e.g.,
ParallelTransformer's hidden_states). Call this function by passing
'keep_graph = True' to 'make_viewless_tensor()'.
'''
@staticmethod
def forward(ctx, inp, requires_grad):
return _kernel_make_viewless_tensor(inp, requires_grad)
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
def make_viewless_tensor(inp, requires_grad, keep_graph):
'''
Entry-point for creating viewless tensors.
This method should be used, rather than calling 'MakeViewlessTensor'
or '_kernel_make_viewless_tensor' directly. This method acts as a
switch for determining if an autograd function or a regular method
should be used to create the tensor.
'''
# return tensor as-is, if not a 'view'
if inp._base is None:
return inp
# create viewless tensor
if keep_graph:
return MakeViewlessTensor.apply(inp, requires_grad)
else:
return _kernel_make_viewless_tensor(inp, requires_grad)
def assert_viewless_tensor(tensor, extra_msg = None):
'''Assert that a tensor is not a view (i.e., its '._base' field is
not set).'''
if isinstance(tensor, list):
[ assert_viewless_tensor(t) for t in tensor ]
return tensor
if not isinstance(tensor, torch.Tensor):
return tensor
assert tensor._base is None, (
"Ensure tensor._base is None before setting tensor.data or storing "
"tensor to memory buffer. Otherwise, a memory leak will occur (and "
"likely accumulate over iterations). %s"
) % extra_msg
return tensor
def safely_set_viewless_tensor_data(tensor, new_data_tensor):
'''Safely set tensor's '.data' field.
Check first that the tensor is viewless (i.e., '._base' not set). If not,
raise an exception.
'''
assert_viewless_tensor(tensor, extra_msg = "FYI, tensor._base has shape %s, and new_data_tensor has shape %s." % ("--" if tensor._base is None else tensor._base.shape, new_data_tensor.shape))
tensor.data = new_data_tensor
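The viewless-tensor helpers relocated here (alongside `GlobalMemoryBuffer` above) are consumed elsewhere in this commit, e.g. by `fused_layer_norm.py` below. A self-contained sketch of `make_viewless_tensor`:

```python
import torch
from megatron.core.utils import make_viewless_tensor

x = torch.randn(4, 4)
view = x.view(2, 8)  # view._base is x
y = make_viewless_tensor(view, requires_grad=False, keep_graph=False)
# y shares view's storage but carries no '._base' back-reference to x
```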
...@@ -4,8 +4,6 @@ ...@@ -4,8 +4,6 @@
import os import os
import sys import sys
from functools import reduce
import operator
import torch import torch
from megatron import dist_signal_handler from megatron import dist_signal_handler
...@@ -20,7 +18,6 @@ _GLOBAL_TENSORBOARD_WRITER = None ...@@ -20,7 +18,6 @@ _GLOBAL_TENSORBOARD_WRITER = None
_GLOBAL_ADLR_AUTORESUME = None _GLOBAL_ADLR_AUTORESUME = None
_GLOBAL_TIMERS = None _GLOBAL_TIMERS = None
_GLOBAL_SIGNAL_HANDLER = None _GLOBAL_SIGNAL_HANDLER = None
_GLOBAL_MEMORY_BUFFER = None
def get_args(): def get_args():
"""Return arguments.""" """Return arguments."""
...@@ -70,11 +67,6 @@ def get_signal_handler(): ...@@ -70,11 +67,6 @@ def get_signal_handler():
return _GLOBAL_SIGNAL_HANDLER return _GLOBAL_SIGNAL_HANDLER
def get_global_memory_buffer():
_ensure_var_is_initialized(_GLOBAL_MEMORY_BUFFER, 'global memory buffer')
return _GLOBAL_MEMORY_BUFFER
def _set_signal_handler(): def _set_signal_handler():
global _GLOBAL_SIGNAL_HANDLER global _GLOBAL_SIGNAL_HANDLER
_ensure_var_is_not_initialized(_GLOBAL_SIGNAL_HANDLER, 'signal handler') _ensure_var_is_not_initialized(_GLOBAL_SIGNAL_HANDLER, 'signal handler')
...@@ -96,7 +88,6 @@ def set_global_variables(args): ...@@ -96,7 +88,6 @@ def set_global_variables(args):
_set_tensorboard_writer(args) _set_tensorboard_writer(args)
_set_adlr_autoresume(args) _set_adlr_autoresume(args)
_set_timers(args) _set_timers(args)
_set_global_memory_buffer()
if args.exit_signal_handler: if args.exit_signal_handler:
_set_signal_handler() _set_signal_handler()
...@@ -176,13 +167,6 @@ def _set_timers(args): ...@@ -176,13 +167,6 @@ def _set_timers(args):
_GLOBAL_TIMERS = Timers(args.timing_log_level, args.timing_log_option) _GLOBAL_TIMERS = Timers(args.timing_log_level, args.timing_log_option)
def _set_global_memory_buffer():
"""Initialize global buffer"""
global _GLOBAL_MEMORY_BUFFER
_ensure_var_is_not_initialized(_GLOBAL_MEMORY_BUFFER, 'global memory buffer')
_GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer()
def _ensure_var_is_initialized(var, name): def _ensure_var_is_initialized(var, name):
"""Make sure the input variable is not None.""" """Make sure the input variable is not None."""
assert var is not None, '{} is not initialized.'.format(name) assert var is not None, '{} is not initialized.'.format(name)
...@@ -194,22 +178,3 @@ def _ensure_var_is_not_initialized(var, name): ...@@ -194,22 +178,3 @@ def _ensure_var_is_not_initialized(var, name):
class GlobalMemoryBuffer:
"""Global buffer to avoid dynamic memory allocations.
Caller should ensure that buffers of the same name
are not used concurrently."""
def __init__(self):
self.buffer = {}
def get_tensor(self, tensor_shape, dtype, name):
required_len = reduce(operator.mul, tensor_shape, 1)
if self.buffer.get((name, dtype), None) is None or \
self.buffer[(name, dtype)].numel() < required_len:
self.buffer[(name, dtype)] = \
torch.empty(required_len,
dtype=dtype,
device=torch.cuda.current_device(),
requires_grad=False)
return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape)
...@@ -219,7 +219,7 @@ def _set_random_seed(seed_, data_parallel_random_init=False): ...@@ -219,7 +219,7 @@ def _set_random_seed(seed_, data_parallel_random_init=False):
np.random.seed(seed) np.random.seed(seed)
torch.manual_seed(seed) torch.manual_seed(seed)
if torch.cuda.device_count() > 0: if torch.cuda.device_count() > 0:
mpu.model_parallel_cuda_manual_seed(seed) core.tensor_parallel.model_parallel_cuda_manual_seed(seed)
else: else:
raise ValueError('Seed ({}) should be a positive integer.'.format(seed)) raise ValueError('Seed ({}) should be a positive integer.'.format(seed))
......
...@@ -10,7 +10,7 @@ from torch.nn.parameter import Parameter ...@@ -10,7 +10,7 @@ from torch.nn.parameter import Parameter
from torch.nn import init from torch.nn import init
import importlib import importlib
from megatron.mpu import make_viewless_tensor from megatron.core.utils import make_viewless_tensor
try: try:
from apex.contrib.layer_norm.layer_norm import FastLayerNormFN from apex.contrib.layer_norm.layer_norm import FastLayerNormFN
......
...@@ -6,7 +6,7 @@ import torch ...@@ -6,7 +6,7 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from megatron import get_args from megatron import get_args
from megatron import mpu from megatron import core
from .module import MegatronModule from .module import MegatronModule
from megatron.model.enums import LayerType, AttnMaskType from megatron.model.enums import LayerType, AttnMaskType
from megatron.model.transformer import ParallelTransformer from megatron.model.transformer import ParallelTransformer
...@@ -22,24 +22,27 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, ...@@ -22,24 +22,27 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
if args.async_tensor_model_parallel_allreduce or\ if args.async_tensor_model_parallel_allreduce or\
args.sequence_parallel: args.sequence_parallel:
input_parallel = input_ input_parallel = input_
model_parallel = mpu.get_tensor_model_parallel_world_size() > 1 model_parallel = core.get_tensor_model_parallel_world_size() > 1
async_grad_allreduce = args.async_tensor_model_parallel_allreduce and \ async_grad_allreduce = args.async_tensor_model_parallel_allreduce and \
model_parallel and not args.sequence_parallel model_parallel and not args.sequence_parallel
else: else:
input_parallel = mpu.copy_to_tensor_model_parallel_region(input_) input_parallel = core.tensor_parallel.copy_to_tensor_model_parallel_region(input_)
async_grad_allreduce = False async_grad_allreduce = False
# Matrix multiply. # Matrix multiply.
logits_parallel = mpu.LinearWithGradAccumulationAndAsyncCommunication.apply( logits_parallel = core.tensor_parallel.linear_with_grad_accumulation_and_async_allreduce(
input_parallel, word_embeddings_weight, bias, input=input_parallel,
args.gradient_accumulation_fusion, weight=word_embeddings_weight,
async_grad_allreduce, args.sequence_parallel) bias=bias,
gradient_accumulation_fusion=args.gradient_accumulation_fusion,
async_grad_allreduce=async_grad_allreduce,
sequence_parallel_enabled=args.sequence_parallel)
# Gather if needed. # Gather if needed.
if parallel_output: if parallel_output:
return logits_parallel return logits_parallel
return mpu.gather_from_tensor_model_parallel_region(logits_parallel) return core.tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel)
def get_language_model(num_tokentypes, add_pooler, def get_language_model(num_tokentypes, add_pooler,
...@@ -103,7 +106,7 @@ class Pooler(MegatronModule): ...@@ -103,7 +106,7 @@ class Pooler(MegatronModule):
# gather data along sequence dimensions # gather data along sequence dimensions
# same pooler is run on all tensor parallel nodes # same pooler is run on all tensor parallel nodes
if self.sequence_parallel: if self.sequence_parallel:
hidden_states = mpu.gather_from_sequence_parallel_region( hidden_states = core.tensor_parallel.gather_from_sequence_parallel_region(
hidden_states, hidden_states,
tensor_parallel_output_grad=False) tensor_parallel_output_grad=False)
...@@ -143,9 +146,13 @@ class Embedding(MegatronModule): ...@@ -143,9 +146,13 @@ class Embedding(MegatronModule):
args = get_args() args = get_args()
# Word embeddings (parallel). # Word embeddings (parallel).
self.word_embeddings = mpu.VocabParallelEmbedding( self.word_embeddings = core.tensor_parallel.VocabParallelEmbedding(
vocab_size, self.hidden_size, vocab_size, self.hidden_size,
init_method=self.init_method) init_method=self.init_method,
params_dtype=args.params_dtype,
use_cpu_initialization=args.use_cpu_initialization,
perform_initialization=args.perform_initialization
)
self._word_embeddings_key = 'word_embeddings' self._word_embeddings_key = 'word_embeddings'
# Position embedding (serial). # Position embedding (serial).
...@@ -222,8 +229,8 @@ class Embedding(MegatronModule): ...@@ -222,8 +229,8 @@ class Embedding(MegatronModule):
# Dropout. # Dropout.
if self.sequence_parallel: if self.sequence_parallel:
embeddings = mpu.scatter_to_sequence_parallel_region(embeddings) embeddings = core.tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
with mpu.get_cuda_rng_tracker().fork(): with core.tensor_parallel.get_cuda_rng_tracker().fork():
embeddings = self.embedding_dropout(embeddings) embeddings = self.embedding_dropout(embeddings)
else: else:
embeddings = self.embedding_dropout(embeddings) embeddings = self.embedding_dropout(embeddings)
......
...@@ -8,6 +8,7 @@ from torch.nn.parameter import Parameter ...@@ -8,6 +8,7 @@ from torch.nn.parameter import Parameter
from megatron import get_args from megatron import get_args
from megatron import mpu from megatron import mpu
from megatron import core
_FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor) _FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
...@@ -76,9 +77,12 @@ class MegatronModule(torch.nn.Module): ...@@ -76,9 +77,12 @@ class MegatronModule(torch.nn.Module):
self._word_embeddings_for_head_key = 'word_embeddings_for_head' self._word_embeddings_for_head_key = 'word_embeddings_for_head'
# set word_embeddings weights to 0 here, then copy first # set word_embeddings weights to 0 here, then copy first
# stage's weights using all_reduce below. # stage's weights using all_reduce below.
self.word_embeddings = mpu.VocabParallelEmbedding( self.word_embeddings = core.tensor_parallel.VocabParallelEmbedding(
args.padded_vocab_size, args.hidden_size, args.padded_vocab_size, args.hidden_size,
init_method=init_method_normal(args.init_method_std)) init_method=init_method_normal(args.init_method_std),
params_dtype=args.params_dtype,
use_cpu_initialization=args.use_cpu_initialization,
perform_initialization=args.perform_initialization)
self.word_embeddings.weight.data.fill_(0) self.word_embeddings.weight.data.fill_(0)
self.word_embeddings.weight.shared = True self.word_embeddings.weight.shared = True
......
...@@ -6,8 +6,9 @@ from contextlib import nullcontext ...@@ -6,8 +6,9 @@ from contextlib import nullcontext
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from megatron import get_timers, get_args, get_global_memory_buffer from megatron import get_timers, get_args
from megatron import mpu from megatron.core import get_global_memory_buffer
from megatron import core
from .module import MegatronModule from .module import MegatronModule
from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
from megatron.model import LayerNorm from megatron.model import LayerNorm
...@@ -32,7 +33,7 @@ from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu ...@@ -32,7 +33,7 @@ from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
""" """
class DropPath(MegatronModule): class DropPath(MegatronModule):
"""Drop paths (Stochastic Depth) per sample """Drop paths (Stochastic Depth) per sample
(when applied in main path of residual blocks). (when applied in main path of residual blocks).
""" """
...@@ -52,6 +53,17 @@ class DropPath(MegatronModule): ...@@ -52,6 +53,17 @@ class DropPath(MegatronModule):
output = hidden_state.div(keep_prob) * random_tensor output = hidden_state.div(keep_prob) * random_tensor
return output return output
def _args_to_kwargs():
args = get_args()
common_kwargs = {
"params_dtype": args.params_dtype,
"use_cpu_initialization": args.use_cpu_initialization,
"perform_initialization": args.perform_initialization,
"gradient_accumulation_fusion": args.gradient_accumulation_fusion,
"sequence_parallel_enabled": args.sequence_parallel,
}
return common_kwargs
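The new _args_to_kwargs() helper above bundles the options that the relocated core.tensor_parallel layers now take as explicit constructor arguments (presumably so the core package no longer reads them from the global args object); the ColumnParallelLinear/RowParallelLinear hunks below spread it with **_args_to_kwargs(). A minimal sketch of what the bundle could contain for an illustrative fp16 configuration; the concrete values are assumptions, not taken from this diff:

    import torch

    # Hypothetical contents of _args_to_kwargs() for an fp16 run with GPU-side
    # weight initialization and sequence parallelism disabled (illustrative only).
    common_kwargs = {
        "params_dtype": torch.float16,
        "use_cpu_initialization": False,
        "perform_initialization": True,
        "gradient_accumulation_fusion": False,
        "sequence_parallel_enabled": False,
    }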
class ParallelMLP(MegatronModule): class ParallelMLP(MegatronModule):
"""MLP. """MLP.
...@@ -65,13 +77,16 @@ class ParallelMLP(MegatronModule): ...@@ -65,13 +77,16 @@ class ParallelMLP(MegatronModule):
super(ParallelMLP, self).__init__() super(ParallelMLP, self).__init__()
args = get_args() args = get_args()
# Project to 4h. # Project to 4h.
self.dense_h_to_4h = mpu.ColumnParallelLinear( self.dense_h_to_4h = core.tensor_parallel.ColumnParallelLinear(
args.hidden_size, args.hidden_size,
args.ffn_hidden_size, args.ffn_hidden_size,
gather_output=False, gather_output=False,
init_method=init_method, init_method=init_method,
skip_bias_add=True) skip_bias_add=True,
async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
**_args_to_kwargs())
self.bias_gelu_fusion = args.bias_gelu_fusion self.bias_gelu_fusion = args.bias_gelu_fusion
self.activation_func = F.gelu self.activation_func = F.gelu
...@@ -81,12 +96,13 @@ class ParallelMLP(MegatronModule): ...@@ -81,12 +96,13 @@ class ParallelMLP(MegatronModule):
self.activation_func = erf_gelu self.activation_func = erf_gelu
# Project back to h. # Project back to h.
self.dense_4h_to_h = mpu.RowParallelLinear( self.dense_4h_to_h = core.tensor_parallel.RowParallelLinear(
args.ffn_hidden_size, args.ffn_hidden_size,
args.hidden_size, args.hidden_size,
input_is_parallel=True, input_is_parallel=True,
init_method=output_layer_init_method, init_method=output_layer_init_method,
skip_bias_add=True) skip_bias_add=True,
**_args_to_kwargs())
def forward(self, hidden_states): def forward(self, hidden_states):
...@@ -136,7 +152,7 @@ class SwitchMLP(MegatronModule): ...@@ -136,7 +152,7 @@ class SwitchMLP(MegatronModule):
output_total = torch.empty_like(hidden_states) output_total = torch.empty_like(hidden_states)
output_bias_total = torch.empty_like(hidden_states) output_bias_total = torch.empty_like(hidden_states)
#TODO (rprenger) This does each expert in serial, but it could be parallelized #TODO (rprenger) This does each expert in serial, but it could be parallelized
for expert_num, expert in enumerate(self.experts): for expert_num, expert in enumerate(self.experts):
local_indices = (max_ind == expert_num).nonzero() local_indices = (max_ind == expert_num).nonzero()
hidden = hidden_states[local_indices,:] hidden = hidden_states[local_indices,:]
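For the routing lines above, a standalone toy example of how (max_ind == expert_num).nonzero() selects each expert's tokens, assuming hidden_states has already been flattened to [tokens, hidden]; sizes and values are illustrative only:

    import torch

    hidden_states = torch.arange(12.0).reshape(6, 2)   # 6 tokens, hidden size 2
    max_ind = torch.tensor([0, 1, 0, 1, 1, 0])          # router's argmax expert per token

    for expert_num in range(2):
        # Positions of the tokens routed to this expert, shape [num_local, 1].
        local_indices = (max_ind == expert_num).nonzero()
        hidden = hidden_states[local_indices, :]         # shape [num_local, 1, hidden]
        print(expert_num, local_indices.squeeze(-1).tolist(), tuple(hidden.shape))
    # expert 0 owns tokens [0, 2, 5]; expert 1 owns tokens [1, 3, 4].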
...@@ -173,12 +189,12 @@ class CoreAttention(MegatronModule): ...@@ -173,12 +189,12 @@ class CoreAttention(MegatronModule):
projection_size = args.kv_channels * args.num_attention_heads projection_size = args.kv_channels * args.num_attention_heads
# Per attention head and per partition values. # Per attention head and per partition values.
world_size = mpu.get_tensor_model_parallel_world_size() world_size = core.get_tensor_model_parallel_world_size()
self.hidden_size_per_partition = mpu.divide(projection_size, self.hidden_size_per_partition = core.utils.divide(projection_size,
world_size) world_size)
self.hidden_size_per_attention_head = mpu.divide( self.hidden_size_per_attention_head = core.utils.divide(
projection_size, args.num_attention_heads) projection_size, args.num_attention_heads)
self.num_attention_heads_per_partition = mpu.divide( self.num_attention_heads_per_partition = core.utils.divide(
args.num_attention_heads, world_size) args.num_attention_heads, world_size)
coeff = None coeff = None
...@@ -247,7 +263,7 @@ class CoreAttention(MegatronModule): ...@@ -247,7 +263,7 @@ class CoreAttention(MegatronModule):
# seem a bit unusual, but is taken from the original Transformer paper. # seem a bit unusual, but is taken from the original Transformer paper.
if not self.sequence_parallel: if not self.sequence_parallel:
with mpu.get_cuda_rng_tracker().fork(): with core.tensor_parallel.get_cuda_rng_tracker().fork():
attention_probs = self.attention_dropout(attention_probs) attention_probs = self.attention_dropout(attention_probs)
else: else:
attention_probs = self.attention_dropout(attention_probs) attention_probs = self.attention_dropout(attention_probs)
...@@ -311,44 +327,52 @@ class ParallelAttention(MegatronModule): ...@@ -311,44 +327,52 @@ class ParallelAttention(MegatronModule):
projection_size = args.kv_channels * args.num_attention_heads projection_size = args.kv_channels * args.num_attention_heads
# Per attention head and per partition values. # Per attention head and per partition values.
world_size = mpu.get_tensor_model_parallel_world_size() world_size = core.get_tensor_model_parallel_world_size()
self.hidden_size_per_attention_head = mpu.divide( self.hidden_size_per_attention_head = core.utils.divide(
projection_size, args.num_attention_heads) projection_size, args.num_attention_heads)
self.num_attention_heads_per_partition = mpu.divide( self.num_attention_heads_per_partition = core.utils.divide(
args.num_attention_heads, world_size) args.num_attention_heads, world_size)
# Strided linear layer. # Strided linear layer.
if attention_type == AttnType.self_attn: if attention_type == AttnType.self_attn:
self.query_key_value = mpu.ColumnParallelLinear( self.query_key_value = core.tensor_parallel.ColumnParallelLinear(
args.hidden_size, args.hidden_size,
3 * projection_size, 3 * projection_size,
gather_output=False, gather_output=False,
init_method=init_method) init_method=init_method,
async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
**_args_to_kwargs())
else: else:
assert attention_type == AttnType.cross_attn assert attention_type == AttnType.cross_attn
self.query = mpu.ColumnParallelLinear( self.query = core.tensor_parallel.ColumnParallelLinear(
args.hidden_size, args.hidden_size,
projection_size, projection_size,
gather_output=False, gather_output=False,
init_method=init_method) init_method=init_method,
async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
**_args_to_kwargs())
self.key_value = mpu.ColumnParallelLinear( self.key_value = core.tensor_parallel.ColumnParallelLinear(
args.hidden_size, args.hidden_size,
2 * projection_size, 2 * projection_size,
gather_output=False, gather_output=False,
init_method=init_method) init_method=init_method,
async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
**_args_to_kwargs())
self.core_attention = CoreAttention(self.layer_number, self.core_attention = CoreAttention(self.layer_number,
self.attn_mask_type) self.attn_mask_type)
self.checkpoint_core_attention = args.recompute_granularity == 'selective' self.checkpoint_core_attention = args.recompute_granularity == 'selective'
# Output. # Output.
self.dense = mpu.RowParallelLinear( self.dense = core.tensor_parallel.RowParallelLinear(
projection_size, projection_size,
args.hidden_size, args.hidden_size,
input_is_parallel=True, input_is_parallel=True,
init_method=output_layer_init_method, init_method=output_layer_init_method,
skip_bias_add=True) skip_bias_add=True,
**_args_to_kwargs())
def _checkpointed_attention_forward(self, query_layer, key_layer, def _checkpointed_attention_forward(self, query_layer, key_layer,
value_layer, attention_mask): value_layer, attention_mask):
...@@ -362,7 +386,7 @@ class ParallelAttention(MegatronModule): ...@@ -362,7 +386,7 @@ class ParallelAttention(MegatronModule):
value_layer, attention_mask) value_layer, attention_mask)
return output_ return output_
hidden_states = mpu.checkpoint( hidden_states = core.tensor_parallel.checkpoint(
custom_forward, custom_forward,
False, query_layer, key_layer, value_layer, attention_mask) False, query_layer, key_layer, value_layer, attention_mask)
...@@ -415,7 +439,7 @@ class ParallelAttention(MegatronModule): ...@@ -415,7 +439,7 @@ class ParallelAttention(MegatronModule):
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
(query_layer, (query_layer,
key_layer, key_layer,
value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3) value_layer) = core.tensor_parallel.split_tensor_along_last_dim(mixed_x_layer, 3)
else: else:
# Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)] # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
mixed_kv_layer, _ = self.key_value(encoder_output) mixed_kv_layer, _ = self.key_value(encoder_output)
...@@ -428,7 +452,7 @@ class ParallelAttention(MegatronModule): ...@@ -428,7 +452,7 @@ class ParallelAttention(MegatronModule):
# [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn] # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
(key_layer, (key_layer,
value_layer) = mpu.split_tensor_along_last_dim(mixed_kv_layer, 2) value_layer) = core.tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)
# Attention head [sq, b, h] --> [sq, b, hp] # Attention head [sq, b, h] --> [sq, b, hp]
query_layer, _ = self.query(hidden_states) query_layer, _ = self.query(hidden_states)
...@@ -674,9 +698,9 @@ class ParallelTransformerLayer(MegatronModule): ...@@ -674,9 +698,9 @@ class ParallelTransformerLayer(MegatronModule):
# won't result in memory savings (like the data loader, or # won't result in memory savings (like the data loader, or
# p2p_communication), it serves to document the origin of this # p2p_communication), it serves to document the origin of this
# 'view' tensor. # 'view' tensor.
output = mpu.make_viewless_tensor(inp = output, output = core.utils.make_viewless_tensor(inp = output,
requires_grad = output.requires_grad, requires_grad = output.requires_grad,
keep_graph = True) keep_graph = True)
else: else:
out = torch.nn.functional.dropout(mlp_output + mlp_bias, out = torch.nn.functional.dropout(mlp_output + mlp_bias,
...@@ -719,7 +743,7 @@ class ParallelTransformer(MegatronModule): ...@@ -719,7 +743,7 @@ class ParallelTransformer(MegatronModule):
def __init__(self, init_method, output_layer_init_method, def __init__(self, init_method, output_layer_init_method,
layer_type=LayerType.encoder, layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding, self_attn_mask_type=AttnMaskType.padding,
post_layer_norm=True, post_layer_norm=True,
pre_process=True, post_process=True, pre_process=True, post_process=True,
drop_path_rate=0.0): drop_path_rate=0.0):
super(ParallelTransformer, self).__init__() super(ParallelTransformer, self).__init__()
...@@ -745,7 +769,7 @@ class ParallelTransformer(MegatronModule): ...@@ -745,7 +769,7 @@ class ParallelTransformer(MegatronModule):
self.sequence_parallel = args.sequence_parallel self.sequence_parallel = args.sequence_parallel
# Number of layers. # Number of layers.
self.num_layers = mpu.get_num_layers( self.num_layers = core.get_num_layers(
args, args.model_type == ModelType.encoder_and_decoder) args, args.model_type == ModelType.encoder_and_decoder)
self.drop_path_rates = [rate.item() for rate in torch.linspace(0, self.drop_path_rate, args.num_layers)] self.drop_path_rates = [rate.item() for rate in torch.linspace(0, self.drop_path_rate, args.num_layers)]
...@@ -775,21 +799,21 @@ class ParallelTransformer(MegatronModule): ...@@ -775,21 +799,21 @@ class ParallelTransformer(MegatronModule):
# layers to stages like (each list is a model chunk): # layers to stages like (each list is a model chunk):
# Stage 0: [0, 1] [4, 5] # Stage 0: [0, 1] [4, 5]
# Stage 1: [2, 3] [6, 7] # Stage 1: [2, 3] [6, 7]
offset = mpu.get_virtual_pipeline_model_parallel_rank() * ( offset = core.get_virtual_pipeline_model_parallel_rank() * (
args.num_layers // args.virtual_pipeline_model_parallel_size) + \ args.num_layers // args.virtual_pipeline_model_parallel_size) + \
(mpu.get_pipeline_model_parallel_rank() * self.num_layers) (core.get_pipeline_model_parallel_rank() * self.num_layers)
else: else:
# Each stage gets a contiguous set of layers. # Each stage gets a contiguous set of layers.
if args.model_type == ModelType.encoder_and_decoder and \ if args.model_type == ModelType.encoder_and_decoder and \
mpu.get_pipeline_model_parallel_world_size() > 1: core.get_pipeline_model_parallel_world_size() > 1:
pipeline_rank = mpu.get_pipeline_model_parallel_rank() pipeline_rank = core.get_pipeline_model_parallel_rank()
if layer_type == LayerType.encoder: if layer_type == LayerType.encoder:
offset = pipeline_rank * self.num_layers offset = pipeline_rank * self.num_layers
else: else:
num_ranks_in_enc = args.pipeline_model_parallel_split_rank num_ranks_in_enc = args.pipeline_model_parallel_split_rank
offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
else: else:
offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers offset = core.get_pipeline_model_parallel_rank() * self.num_layers
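The interleaved assignment described in the comment above (Stage 0: [0, 1] [4, 5]; Stage 1: [2, 3] [6, 7]) can be checked with a standalone arithmetic sketch of the virtual-pipeline offset formula, using the same 8-layer, 2-chunk, 2-stage configuration:

    # Self-contained check of the virtual-pipeline offset arithmetic above.
    num_layers = 8
    virtual_pipeline_model_parallel_size = 2
    pipeline_model_parallel_world_size = 2
    # Layers held by each model chunk on each pipeline stage.
    layers_per_chunk = num_layers // virtual_pipeline_model_parallel_size \
        // pipeline_model_parallel_world_size  # 2

    for virtual_rank in range(virtual_pipeline_model_parallel_size):
        for pipeline_rank in range(pipeline_model_parallel_world_size):
            offset = virtual_rank * (num_layers // virtual_pipeline_model_parallel_size) \
                + pipeline_rank * layers_per_chunk
            chunk = list(range(offset, offset + layers_per_chunk))
            print(f"stage {pipeline_rank}, chunk {virtual_rank}: layers {chunk}")
    # stage 0 gets [0, 1] and [4, 5]; stage 1 gets [2, 3] and [6, 7],
    # matching the comment in the hunk above.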
if self.num_layers == 0: if self.num_layers == 0:
# When a standalone embedding stage is used (e.g., # When a standalone embedding stage is used (e.g.,
...@@ -838,7 +862,7 @@ class ParallelTransformer(MegatronModule): ...@@ -838,7 +862,7 @@ class ParallelTransformer(MegatronModule):
# A method to further reduce memory usage reducing checkpoints. # A method to further reduce memory usage reducing checkpoints.
l = 0 l = 0
while l < self.num_layers: while l < self.num_layers:
hidden_states = mpu.checkpoint( hidden_states = core.tensor_parallel.checkpoint(
custom(l, l + self.recompute_num_layers), custom(l, l + self.recompute_num_layers),
self.distribute_saved_activations, self.distribute_saved_activations,
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask) hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
...@@ -850,7 +874,7 @@ class ParallelTransformer(MegatronModule): ...@@ -850,7 +874,7 @@ class ParallelTransformer(MegatronModule):
# A method fully use the device memory removing redundant re-computation. # A method fully use the device memory removing redundant re-computation.
for l in range(self.num_layers): for l in range(self.num_layers):
if l < self.recompute_num_layers: if l < self.recompute_num_layers:
hidden_states = mpu.checkpoint( hidden_states = core.tensor_parallel.checkpoint(
custom(l, l + 1), custom(l, l + 1),
self.distribute_saved_activations, self.distribute_saved_activations,
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask) hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
...@@ -896,19 +920,19 @@ class ParallelTransformer(MegatronModule): ...@@ -896,19 +920,19 @@ class ParallelTransformer(MegatronModule):
# However, we don't explicitly check mbs == 1 here because # However, we don't explicitly check mbs == 1 here because
# make_viewless_tensor() has negligible overhead when its input # make_viewless_tensor() has negligible overhead when its input
# is already viewless. # is already viewless.
# #
# - For the 'else' case above, calling make_viewless_tensor() here is # - For the 'else' case above, calling make_viewless_tensor() here is
# likely redundant, since p2p_communication.py (likely originator) # likely redundant, since p2p_communication.py (likely originator)
# already creates viewless tensors. That said, make_viewless_tensor() # already creates viewless tensors. That said, make_viewless_tensor()
# is called here to be future-proof and corner-case-proof. # is called here to be future-proof and corner-case-proof.
hidden_states = mpu.make_viewless_tensor( hidden_states = core.utils.make_viewless_tensor(
hidden_states, hidden_states,
requires_grad=True, requires_grad=True,
keep_graph=True, keep_graph=True,
) )
if self.sequence_parallel: if self.sequence_parallel:
rng_context = mpu.get_cuda_rng_tracker().fork() rng_context = core.tensor_parallel.get_cuda_rng_tracker().fork()
else: else:
rng_context = nullcontext() rng_context = nullcontext()
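The rng_context selection above uses the CUDA RNG tracker that this commit moves into core.tensor_parallel. A minimal sketch of the same pattern in isolation; it assumes model parallelism has been initialized and the per-rank model-parallel seed has been set (model_parallel_cuda_manual_seed is one of the symbols dropped from mpu/__init__.py below), so this is not a runnable single-process snippet:

    from contextlib import nullcontext
    import torch
    from megatron import core

    def dropout_with_tracker(hidden_states, p, sequence_parallel):
        # fork() swaps in the tracked model-parallel CUDA RNG state for the
        # duration of the block and restores the default state afterwards.
        if sequence_parallel:
            rng_context = core.tensor_parallel.get_cuda_rng_tracker().fork()
        else:
            rng_context = nullcontext()
        with rng_context:
            return torch.nn.functional.dropout(hidden_states, p, training=True)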
......
...@@ -32,30 +32,6 @@ from .initialize import get_virtual_pipeline_model_parallel_rank, set_virtual_pi ...@@ -32,30 +32,6 @@ from .initialize import get_virtual_pipeline_model_parallel_rank, set_virtual_pi
from .initialize import initialize_model_parallel from .initialize import initialize_model_parallel
from .initialize import model_parallel_is_initialized from .initialize import model_parallel_is_initialized
from .layers import LinearWithGradAccumulationAndAsyncCommunication
from .layers import ColumnParallelLinear
from .layers import RowParallelLinear
from .layers import VocabParallelEmbedding
from .layers import (set_tensor_model_parallel_attributes,
set_defaults_if_not_set_tensor_model_parallel_attributes,
copy_tensor_model_parallel_attributes)
from .mappings import copy_to_tensor_model_parallel_region
from .mappings import reduce_from_tensor_model_parallel_region
from .mappings import scatter_to_tensor_model_parallel_region
from .mappings import gather_from_tensor_model_parallel_region
from .mappings import scatter_to_sequence_parallel_region
from .mappings import gather_from_sequence_parallel_region
from .mappings import reduce_scatter_to_sequence_parallel_region
from .random import checkpoint
from .random import get_cuda_rng_tracker
from .random import model_parallel_cuda_manual_seed
from .random import gather_split_1d_tensor
from .random import split_tensor_into_1d_equal_chunks
from .random import make_viewless_tensor
from .random import assert_viewless_tensor
from .random import safely_set_viewless_tensor_data
from .utils import divide from .utils import divide
from .utils import split_tensor_along_last_dim from .utils import split_tensor_along_last_dim
...@@ -9,7 +9,7 @@ from apex.multi_tensor_apply import multi_tensor_applier ...@@ -9,7 +9,7 @@ from apex.multi_tensor_apply import multi_tensor_applier
import amp_C import amp_C
from megatron.model.module import param_is_not_shared from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate
def clip_grad_norm_fp32(parameters, grads_for_norm, def clip_grad_norm_fp32(parameters, grads_for_norm,
......
...@@ -11,7 +11,7 @@ from megatron import get_timers ...@@ -11,7 +11,7 @@ from megatron import get_timers
from megatron import mpu from megatron import mpu
from megatron import print_rank_0 from megatron import print_rank_0
from megatron.model.module import param_is_not_shared from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate
from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper
......
...@@ -12,11 +12,12 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors ...@@ -12,11 +12,12 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from megatron import get_timers from megatron import get_timers
from megatron import mpu from megatron import mpu
from megatron import core
from megatron import print_rank_0 from megatron import print_rank_0
from megatron.model import DistributedDataParallel as LocalDDP from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module from megatron.model import Float16Module
from megatron.model.module import param_is_not_shared from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate
from megatron.utils import unwrap_model from megatron.utils import unwrap_model
from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32 from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32
...@@ -527,8 +528,8 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer): ...@@ -527,8 +528,8 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
# Create a copy # Create a copy
main_param = param.detach().clone().float() main_param = param.detach().clone().float()
# Copy tensor model parallel attributes. # Copy tensor model parallel attributes.
mpu.copy_tensor_model_parallel_attributes(main_param, core.tensor_parallel.copy_tensor_model_parallel_attributes(main_param,
param) param)
if hasattr(param, 'shared'): if hasattr(param, 'shared'):
main_param.shared = param.shared main_param.shared = param.shared
# Replace the optimizer params with the new fp32 copy. # Replace the optimizer params with the new fp32 copy.
......
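Taken together, the hunks in this commit define the post-move import surface used by the model and optimizer code. A short recap, restricted to names that actually appear in the diff above and assuming a build of this commit:

    # Symbols referenced in the hunks above, after the move to megatron.core:
    from megatron import core
    from megatron.core import get_global_memory_buffer
    from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate

    # Reached through the core namespace in the model/optimizer code:
    #   core.tensor_parallel.ColumnParallelLinear / RowParallelLinear / VocabParallelEmbedding
    #   core.tensor_parallel.linear_with_grad_accumulation_and_async_allreduce
    #   core.tensor_parallel.copy_to_tensor_model_parallel_region
    #   core.tensor_parallel.gather_from_tensor_model_parallel_region
    #   core.tensor_parallel.scatter_to_sequence_parallel_region / gather_from_sequence_parallel_region
    #   core.tensor_parallel.split_tensor_along_last_dim
    #   core.tensor_parallel.checkpoint / get_cuda_rng_tracker
    #   core.tensor_parallel.copy_tensor_model_parallel_attributes
    #   core.utils.divide / core.utils.make_viewless_tensor
    #   core.get_tensor_model_parallel_world_size, core.get_pipeline_model_parallel_rank,
    #   core.get_virtual_pipeline_model_parallel_rank, core.get_num_layers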