Commit c2ea914f authored by Jared Casper

Move layers from mpu to core.tensor_parallel.

parent 209f91c9
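For orientation while reading the diff below, here is a hedged sketch of the call-site pattern the commit moves to. It is illustrative only; the commented calls stand in for the real constructor arguments, which this commit does not change.

```python
# Old pattern (pre-commit): everything hung off megatron.mpu.
#   from megatron import mpu
#   layer = mpu.ColumnParallelLinear(...)
#   rank = mpu.get_tensor_model_parallel_rank()

# New pattern (post-commit): layers plus RNG/mapping helpers live in
# megatron.core.tensor_parallel, viewless-tensor helpers in megatron.core.utils,
# and parallel-state getters are reached through megatron.core.
from megatron import core

# layer = core.tensor_parallel.ColumnParallelLinear(...)
# rank = core.get_tensor_model_parallel_rank()
# out = core.utils.make_viewless_tensor(t, t.requires_grad, keep_graph=True)
```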
@@ -10,7 +10,6 @@ from .global_vars import get_tokenizer
 from .global_vars import get_tensorboard_writer
 from .global_vars import get_adlr_autoresume
 from .global_vars import get_timers
-from .global_vars import get_global_memory_buffer
 from .initialize import initialize_megatron
 from .utils import (print_rank_0,
......
@@ -9,7 +9,7 @@ import numpy as np
 import torch

-from megatron import (mpu,
+from megatron import (core,
                       update_num_microbatches)
 from .global_vars import get_args
 from .utils import (unwrap_model,
@@ -79,11 +79,11 @@ def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
     # Use both the tensor and pipeline MP rank.
     if pipeline_parallel is None:
-        pipeline_parallel = (mpu.get_pipeline_model_parallel_world_size() > 1)
+        pipeline_parallel = (core.get_pipeline_model_parallel_world_size() > 1)
     if tensor_rank is None:
-        tensor_rank = mpu.get_tensor_model_parallel_rank()
+        tensor_rank = core.get_tensor_model_parallel_rank()
     if pipeline_rank is None:
-        pipeline_rank = mpu.get_pipeline_model_parallel_rank()
+        pipeline_rank = core.get_pipeline_model_parallel_rank()

     # Use both the tensor and pipeline MP rank. If using the distributed
     # optimizer, then the optimizer's path must additionally include the
@@ -98,7 +98,7 @@ def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
     if use_distributed_optimizer:
         model_name = os.path.join(common_path, "model_rng.pt")
         optim_name = os.path.join(
-            common_path + "_%03d" % mpu.get_data_parallel_rank(),
+            common_path + "_%03d" % core.get_data_parallel_rank(),
             "optim.pt")
     else:
         model_name = optim_name = os.path.join(common_path, "model_optim_rng.pt")
@@ -185,18 +185,18 @@ def get_rng_state():
        'np_rng_state': np.random.get_state(),
        'torch_rng_state': torch.get_rng_state(),
        'cuda_rng_state': torch.cuda.get_rng_state(),
-       'rng_tracker_states': mpu.get_cuda_rng_tracker().get_states()}
+       'rng_tracker_states': core.tensor_parallel.get_cuda_rng_tracker().get_states()}

     rng_state_list = None
     if torch.distributed.is_initialized() and \
-            mpu.get_data_parallel_world_size() > 1 and \
+            core.get_data_parallel_world_size() > 1 and \
             args.data_parallel_random_init:
         rng_state_list = \
-            [None for i in range(mpu.get_data_parallel_world_size())]
+            [None for i in range(core.get_data_parallel_world_size())]
         torch.distributed.all_gather_object(
             rng_state_list,
             rng_state,
-            group=mpu.get_data_parallel_group())
+            group=core.get_data_parallel_group())
     else:
         rng_state_list = [rng_state]
@@ -223,7 +223,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
     # Collect args, model, RNG.
     model_state_dict = {}
     if not torch.distributed.is_initialized() \
-            or mpu.get_data_parallel_rank() == 0:
+            or core.get_data_parallel_rank() == 0:

         # Arguments, iteration, and model.
         model_state_dict['args'] = args
@@ -233,7 +233,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
             model_state_dict['model'] = model[0].state_dict_for_save_checkpoint()
         else:
             for i in range(len(model)):
-                mpu.set_virtual_pipeline_model_parallel_rank(i)
+                core.set_virtual_pipeline_model_parallel_rank(i)
                 model_state_dict['model%d' % i] = \
                     model[i].state_dict_for_save_checkpoint()
@@ -246,7 +246,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
     optim_state_dict = {}
     if not args.no_save_optim \
             and (not torch.distributed.is_initialized()
-                 or mpu.get_data_parallel_rank() == 0
+                 or core.get_data_parallel_rank() == 0
                  or args.use_distributed_optimizer):

         # Optimizer stuff.
@@ -548,7 +548,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
         model[0].load_state_dict(model_state_dict['model'], strict=strict)
     else:
         for i in range(len(model)):
-            mpu.set_virtual_pipeline_model_parallel_rank(i)
+            core.set_virtual_pipeline_model_parallel_rank(i)
             model[i].load_state_dict(model_state_dict['model%d' % i], strict=strict)

     # Fix up query/key/value matrix ordering if needed
@@ -580,7 +580,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
             # access rng_state for data parallel rank
             if args.data_parallel_random_init:
-                rng_state = model_state_dict['rng_state'][mpu.get_data_parallel_rank()]
+                rng_state = model_state_dict['rng_state'][core.get_data_parallel_rank()]
             else:
                 rng_state = model_state_dict['rng_state'][0]
             random.setstate(rng_state['random_rng_state'])
@@ -590,7 +590,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
                 # Check for empty states array
                 if not rng_state['rng_tracker_states']:
                     raise KeyError
-                mpu.get_cuda_rng_tracker().set_states(
+                core.tensor_parallel.get_cuda_rng_tracker().set_states(
                     rng_state['rng_tracker_states'])
         else:  # backward compatability
             random.setstate(model_state_dict['random_rng_state'])
@@ -600,7 +600,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
                 # Check for empty states array
                 if not model_state_dict['rng_tracker_states']:
                     raise KeyError
-                mpu.get_cuda_rng_tracker().set_states(
+                core.tensor_parallel.get_cuda_rng_tracker().set_states(
                     model_state_dict['rng_tracker_states'])
     except KeyError:
         print_rank_0('Unable to load rng state from checkpoint {}. '
@@ -640,7 +640,7 @@ def load_biencoder_checkpoint(model, only_query_model=False,
                                               args.use_distributed_optimizer,
                                               release=False)

-    if mpu.get_data_parallel_rank() == 0:
+    if core.get_data_parallel_rank() == 0:
         print('global rank {} is loading checkpoint {}'.format(
             torch.distributed.get_rank(), checkpoint_name))
@@ -656,7 +656,7 @@ def load_biencoder_checkpoint(model, only_query_model=False,
     model[0].load_state_dict(ret_state_dict)
     torch.distributed.barrier()

-    if mpu.get_data_parallel_rank() == 0:
+    if core.get_data_parallel_rank() == 0:
         print(' successfully loaded {}'.format(checkpoint_name))

     return model
 from .parallel_state import (
     initialize_model_parallel,
     get_tensor_model_parallel_world_size,
+    get_tensor_model_parallel_rank,
     get_pipeline_model_parallel_world_size,
+    get_pipeline_model_parallel_rank,
-    get_virtual_pipeline_model_parallel_rank,
+    set_virtual_pipeline_model_parallel_rank,
     get_data_parallel_world_size,
+    get_data_parallel_rank,
+    get_global_memory_buffer,
+    get_num_layers,
 )
 from megatron.core import tensor_parallel
@@ -5,6 +5,8 @@
 import torch
 from typing import Optional

+from .utils import GlobalMemoryBuffer
+
 # Intra-layer model parallel group that the current rank belongs to.
 _TENSOR_MODEL_PARALLEL_GROUP = None
 # Inter-layer model parallel group that the current rank belongs to.
@@ -42,7 +44,8 @@ _PIPELINE_GLOBAL_RANKS = None
 # rank when broadcasting weights from src to all other data parallel ranks
 _DATA_PARALLEL_GLOBAL_RANKS = None

+# Memory buffers to avoid dynamic memory allocation
+_GLOBAL_MEMORY_BUFFER = None

 def is_unitialized():
     """Useful for code segments that may be accessed with or without mpu initialization"""
@@ -195,6 +198,12 @@ def initialize_model_parallel(
         if rank in ranks:
             _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks

+    # Initialize global memory buffer
+    # This isn't really "parallel state" but there isn't another good place to
+    # put this. If we end up with a more generic initialization of megatron-core
+    # we could stick it there
+    _set_global_memory_buffer()

 def model_parallel_is_initialized():
     """Check if model and data parallel groups are initialized."""
@@ -506,6 +515,18 @@ def get_data_parallel_rank():
     """Return my rank for the data parallel group."""
     return torch.distributed.get_rank(group=get_data_parallel_group())

+
+def _set_global_memory_buffer():
+    """Initialize global buffer"""
+    global _GLOBAL_MEMORY_BUFFER
+    assert _GLOBAL_MEMORY_BUFFER is None, 'global memory buffer is already initialized'
+    _GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer()
+
+
+def get_global_memory_buffer():
+    assert _GLOBAL_MEMORY_BUFFER is not None, 'global memory buffer is not initialized'
+    return _GLOBAL_MEMORY_BUFFER
+
+
 def destroy_model_parallel():
     """Set the groups to none."""
     global _MODEL_PARALLEL_GROUP
......
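The global memory buffer moves out of megatron.global_vars (see the global_vars.py hunks further down) and is now created inside initialize_model_parallel(). A minimal sketch of the resulting access pattern, assuming a CUDA device and an already-initialized torch.distributed process group; the buffer name "scratch" is an arbitrary example, not a name used by the codebase:

```python
import torch
from megatron.core import parallel_state

# Prerequisite (not shown): torch.distributed.init_process_group(...) on each rank.
parallel_state.initialize_model_parallel()        # also calls _set_global_memory_buffer()

buf = parallel_state.get_global_memory_buffer()   # asserts if called before initialization
tmp = buf.get_tensor((8, 16), torch.float32, "scratch")
```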
 from .cross_entropy import vocab_parallel_cross_entropy
 from .data import broadcast_data
+from .layers import (
+    ColumnParallelLinear,
+    RowParallelLinear,
+    VocabParallelEmbedding,
+    set_defaults_if_not_set_tensor_model_parallel_attributes,
+    copy_tensor_model_parallel_attributes,
+    param_is_not_tensor_parallel_duplicate,
+    linear_with_grad_accumulation_and_async_allreduce
+)
+from .mappings import (
+    copy_to_tensor_model_parallel_region,
+    gather_from_tensor_model_parallel_region,
+    gather_from_sequence_parallel_region,
+    scatter_to_tensor_model_parallel_region,
+    scatter_to_sequence_parallel_region,
+)
+from .random import (
+    checkpoint,
+    get_cuda_rng_tracker,
+    model_parallel_cuda_manual_seed
+)
+from .utils import split_tensor_along_last_dim

 __all__ = [
     # cross_entropy.py
     "vocab_parallel_cross_entropy",
     # data.py
     "broadcast_data",
+    #layers.py
+    "ColumnParallelLinear",
+    "RowParallelLinear",
+    "VocabParallelEmbedding",
+    "set_defaults_if_not_set_tensor_model_parallel_attributes",
+    "copy_tensor_model_parallel_attributes",
+    "param_is_not_tensor_parallel_duplicate",
+    "linear_with_grad_accumulation_and_async_allreduce",
+    # mappings.py
+    "copy_to_tensor_model_parallel_region",
+    "gather_from_tensor_model_parallel_region",
+    "gather_from_sequence_parallel_region",
+    # "reduce_from_tensor_model_parallel_region",
+    "scatter_to_tensor_model_parallel_region",
+    "scatter_to_sequence_parallel_region",
+    # random.py
+    "checkpoint",
+    "get_cuda_rng_tracker",
+    "model_parallel_cuda_manual_seed",
+    # utils.py
+    "split_tensor_along_last_dim",
 ]
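Downstream code can now pull the relocated helpers from one public surface; a minimal import sketch using names from the __all__ list above (comments only note which submodule re-exports each name):

```python
from megatron.core.tensor_parallel import (
    ColumnParallelLinear,         # layers.py
    VocabParallelEmbedding,       # layers.py
    get_cuda_rng_tracker,         # random.py
    split_tensor_along_last_dim,  # utils.py
)
```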
@@ -2,7 +2,11 @@
 import torch

-from .initialize import get_tensor_model_parallel_group, get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank
+from megatron.core.parallel_state import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    get_tensor_model_parallel_group,
+)
 from .utils import split_tensor_along_last_dim
......
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

 # Parts of the code here are adapted from PyTorch
 # repo: https://github.com/pytorch/pytorch
@@ -11,12 +10,12 @@ from torch import _C
 from torch.cuda import _lazy_call, device as device_ctx_manager
 from torch.utils.checkpoint import detach_variable

-from megatron.memory import allocate_mem_buff
-from .initialize import get_data_parallel_rank
-from .initialize import get_tensor_model_parallel_group
-from .initialize import get_tensor_model_parallel_rank
-from .initialize import get_tensor_model_parallel_world_size
+from megatron.core.parallel_state import (
+    get_data_parallel_rank,
+    get_tensor_model_parallel_group,
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)

 # Default name for the model parallel rng tracker.
@@ -89,85 +88,6 @@ def gather_split_1d_tensor(tensor):
     return gathered

-def _kernel_make_viewless_tensor(inp, requires_grad):
-    '''Make a viewless tensor.
-
-    View tensors have the undesirable side-affect of retaining a reference
-    to the originally-viewed tensor, even after manually setting the '.data'
-    field. This method creates a new tensor that links to the old tensor's
-    data, without linking the viewed tensor, referenced via the '._base'
-    field.
-    '''
-    out = torch.empty(
-        (1,),
-        dtype = inp.dtype,
-        device = inp.device,
-        requires_grad = requires_grad,
-    )
-    out.data = inp.data
-    return out
-
-class MakeViewlessTensor(torch.autograd.Function):
-    '''
-    Autograd function to make a viewless tensor.
-
-    This function should be used in cases where the computation graph needs
-    to be propagated, but we only want a viewless tensor (e.g.,
-    ParallelTransformer's hidden_states). Call this function by passing
-    'keep_graph = True' to 'make_viewless_tensor()'.
-    '''
-    @staticmethod
-    def forward(ctx, inp, requires_grad):
-        return _kernel_make_viewless_tensor(inp, requires_grad)
-    @staticmethod
-    def backward(ctx, grad_output):
-        return grad_output, None
-
-def make_viewless_tensor(inp, requires_grad, keep_graph):
-    '''
-    Entry-point for creating viewless tensors.
-
-    This method should be used, rather than calling 'MakeViewlessTensor'
-    or '_kernel_make_viewless_tensor' directly. This method acts as a
-    switch for determining if an autograd function or a regular method
-    should be used to create the tensor.
-    '''
-
-    # return tensor as-is, if not a 'view'
-    if inp._base is None:
-        return inp
-
-    # create viewless tensor
-    if keep_graph:
-        return MakeViewlessTensor.apply(inp, requires_grad)
-    else:
-        return _kernel_make_viewless_tensor(inp, requires_grad)
-
-def assert_viewless_tensor(tensor, extra_msg = None):
-    '''Assert that a tensor is not a view (i.e., its '._base' field is
-    not set).'''
-    if isinstance(tensor, list):
-        [ assert_viewless_tensor(t) for t in tensor ]
-        return tensor
-    if not isinstance(tensor, torch.Tensor):
-        return tensor
-    assert tensor._base is None, (
-        "Ensure tensor._base is None before setting tensor.data or storing "
-        "tensor to memory buffer. Otherwise, a memory leak will occur (and "
-        "likely accumulate over iterations). %s"
-    ) % extra_msg
-    return tensor
-
-def safely_set_viewless_tensor_data(tensor, new_data_tensor):
-    '''Safely set tensor's '.data' field.
-
-    Check first that the tensor is viewless (i.e., '._base' not set). If not,
-    raise an exception.
-    '''
-    assert_viewless_tensor(tensor, extra_msg = "FYI, tensor._base has shape %s, and new_data_tensor has shape %s." % ("--" if tensor._base is None else tensor._base.shape, new_data_tensor.shape))
-    tensor.data = new_data_tensor

 class CudaRNGStatesTracker:
     """Tracker for the cuda RNG states.
......
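The helpers deleted here are not dropped; they reappear verbatim in megatron/core/utils.py (see that file's hunk below). A small sketch of the relocated entry point, assuming the megatron package and its dependencies are importable:

```python
import torch
from megatron.core.utils import make_viewless_tensor, assert_viewless_tensor

base = torch.zeros(4, 4)
view = base.view(16)                      # a view keeps a reference to `base` via ._base
flat = make_viewless_tensor(view, requires_grad=False, keep_graph=False)

assert view._base is base                 # the view still pins the original tensor object
assert_viewless_tensor(flat)              # passes: flat._base is None, no hidden reference
```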
@@ -5,19 +5,6 @@ from typing import List, Sequence
 from megatron.core.utils import divide

-def ensure_divisibility(numerator, denominator):
-    """Ensure that numerator is divisible by the denominator."""
-    assert numerator % denominator == 0, '{} is not divisible by {}'.format(
-        numerator, denominator)
-
-def divide(numerator, denominator):
-    """Ensure that numerator is divisible by the denominator and return
-    the division value."""
-    ensure_divisibility(numerator, denominator)
-    return numerator // denominator
-
 def split_tensor_along_last_dim(
     tensor: torch.Tensor,
     num_partitions: int,
......
"""Utility functions used through Megatron core""" """Utility functions used throughout Megatron core"""
from functools import reduce
import operator
import torch import torch
from megatron.core import parallel_state from megatron.core import parallel_state
...@@ -46,3 +49,101 @@ def gather_split_1d_tensor(tensor): ...@@ -46,3 +49,101 @@ def gather_split_1d_tensor(tensor):
group=parallel_state.get_tensor_model_parallel_group() group=parallel_state.get_tensor_model_parallel_group()
) )
return gathered return gathered
class GlobalMemoryBuffer:
"""Global buffer to avoid dynamic memory allocations.
Caller should ensure that buffers of the same name
are not used concurrently."""
def __init__(self):
self.buffer = {}
def get_tensor(self, tensor_shape, dtype, name):
required_len = reduce(operator.mul, tensor_shape, 1)
if self.buffer.get((name, dtype), None) is None or \
self.buffer[(name, dtype)].numel() < required_len:
self.buffer[(name, dtype)] = \
torch.empty(required_len,
dtype=dtype,
device=torch.cuda.current_device(),
requires_grad=False)
return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape)
def _kernel_make_viewless_tensor(inp, requires_grad):
'''Make a viewless tensor.
View tensors have the undesirable side-affect of retaining a reference
to the originally-viewed tensor, even after manually setting the '.data'
field. This method creates a new tensor that links to the old tensor's
data, without linking the viewed tensor, referenced via the '._base'
field.
'''
out = torch.empty(
(1,),
dtype = inp.dtype,
device = inp.device,
requires_grad = requires_grad,
)
out.data = inp.data
return out
class MakeViewlessTensor(torch.autograd.Function):
'''
Autograd function to make a viewless tensor.
This function should be used in cases where the computation graph needs
to be propagated, but we only want a viewless tensor (e.g.,
ParallelTransformer's hidden_states). Call this function by passing
'keep_graph = True' to 'make_viewless_tensor()'.
'''
@staticmethod
def forward(ctx, inp, requires_grad):
return _kernel_make_viewless_tensor(inp, requires_grad)
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
def make_viewless_tensor(inp, requires_grad, keep_graph):
'''
Entry-point for creating viewless tensors.
This method should be used, rather than calling 'MakeViewlessTensor'
or '_kernel_make_viewless_tensor' directly. This method acts as a
switch for determining if an autograd function or a regular method
should be used to create the tensor.
'''
# return tensor as-is, if not a 'view'
if inp._base is None:
return inp
# create viewless tensor
if keep_graph:
return MakeViewlessTensor.apply(inp, requires_grad)
else:
return _kernel_make_viewless_tensor(inp, requires_grad)
def assert_viewless_tensor(tensor, extra_msg = None):
'''Assert that a tensor is not a view (i.e., its '._base' field is
not set).'''
if isinstance(tensor, list):
[ assert_viewless_tensor(t) for t in tensor ]
return tensor
if not isinstance(tensor, torch.Tensor):
return tensor
assert tensor._base is None, (
"Ensure tensor._base is None before setting tensor.data or storing "
"tensor to memory buffer. Otherwise, a memory leak will occur (and "
"likely accumulate over iterations). %s"
) % extra_msg
return tensor
def safely_set_viewless_tensor_data(tensor, new_data_tensor):
'''Safely set tensor's '.data' field.
Check first that the tensor is viewless (i.e., '._base' not set). If not,
raise an exception.
'''
assert_viewless_tensor(tensor, extra_msg = "FYI, tensor._base has shape %s, and new_data_tensor has shape %s." % ("--" if tensor._base is None else tensor._base.shape, new_data_tensor.shape))
tensor.data = new_data_tensor
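A short usage sketch of the GlobalMemoryBuffer added above. It assumes a CUDA device, since get_tensor allocates on torch.cuda.current_device(); the buffer name "scratch" is arbitrary.

```python
import torch
from megatron.core.utils import GlobalMemoryBuffer

buf = GlobalMemoryBuffer()
a = buf.get_tensor((2, 3), torch.float32, "scratch")  # first call: allocates 6 elements
b = buf.get_tensor((3, 2), torch.float32, "scratch")  # same (name, dtype): storage is reused
assert a.data_ptr() == b.data_ptr()                   # both views start at the same allocation
c = buf.get_tensor((4, 4), torch.float32, "scratch")  # larger request: buffer grows to 16 elements
```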
@@ -4,8 +4,6 @@
 import os
 import sys
-from functools import reduce
-import operator

 import torch

 from megatron import dist_signal_handler
@@ -20,7 +18,6 @@ _GLOBAL_TENSORBOARD_WRITER = None
 _GLOBAL_ADLR_AUTORESUME = None
 _GLOBAL_TIMERS = None
 _GLOBAL_SIGNAL_HANDLER = None
-_GLOBAL_MEMORY_BUFFER = None

 def get_args():
     """Return arguments."""
@@ -70,11 +67,6 @@ def get_signal_handler():
     return _GLOBAL_SIGNAL_HANDLER

-def get_global_memory_buffer():
-    _ensure_var_is_initialized(_GLOBAL_MEMORY_BUFFER, 'global memory buffer')
-    return _GLOBAL_MEMORY_BUFFER
-
 def _set_signal_handler():
     global _GLOBAL_SIGNAL_HANDLER
     _ensure_var_is_not_initialized(_GLOBAL_SIGNAL_HANDLER, 'signal handler')
@@ -96,7 +88,6 @@ def set_global_variables(args):
     _set_tensorboard_writer(args)
     _set_adlr_autoresume(args)
     _set_timers(args)
-    _set_global_memory_buffer()

     if args.exit_signal_handler:
         _set_signal_handler()
@@ -176,13 +167,6 @@ def _set_timers(args):
     _GLOBAL_TIMERS = Timers(args.timing_log_level, args.timing_log_option)

-def _set_global_memory_buffer():
-    """Initialize global buffer"""
-    global _GLOBAL_MEMORY_BUFFER
-    _ensure_var_is_not_initialized(_GLOBAL_MEMORY_BUFFER, 'global memory buffer')
-    _GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer()
-
 def _ensure_var_is_initialized(var, name):
     """Make sure the input variable is not None."""
     assert var is not None, '{} is not initialized.'.format(name)
@@ -194,22 +178,3 @@ def _ensure_var_is_not_initialized(var, name):

-class GlobalMemoryBuffer:
-    """Global buffer to avoid dynamic memory allocations.
-    Caller should ensure that buffers of the same name
-    are not used concurrently."""
-
-    def __init__(self):
-        self.buffer = {}
-
-    def get_tensor(self, tensor_shape, dtype, name):
-        required_len = reduce(operator.mul, tensor_shape, 1)
-        if self.buffer.get((name, dtype), None) is None or \
-                self.buffer[(name, dtype)].numel() < required_len:
-            self.buffer[(name, dtype)] = \
-                torch.empty(required_len,
-                            dtype=dtype,
-                            device=torch.cuda.current_device(),
-                            requires_grad=False)
-        return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape)
@@ -219,7 +219,7 @@ def _set_random_seed(seed_, data_parallel_random_init=False):
         np.random.seed(seed)
         torch.manual_seed(seed)
         if torch.cuda.device_count() > 0:
-            mpu.model_parallel_cuda_manual_seed(seed)
+            core.tensor_parallel.model_parallel_cuda_manual_seed(seed)
     else:
         raise ValueError('Seed ({}) should be a positive integer.'.format(seed))
......
@@ -10,7 +10,7 @@ from torch.nn.parameter import Parameter
 from torch.nn import init
 import importlib

-from megatron.mpu import make_viewless_tensor
+from megatron.core.utils import make_viewless_tensor

 try:
     from apex.contrib.layer_norm.layer_norm import FastLayerNormFN
......
@@ -6,7 +6,7 @@ import torch
 import torch.nn.functional as F

 from megatron import get_args
-from megatron import mpu
+from megatron import core
 from .module import MegatronModule
 from megatron.model.enums import LayerType, AttnMaskType
 from megatron.model.transformer import ParallelTransformer
@@ -22,24 +22,27 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
     if args.async_tensor_model_parallel_allreduce or\
             args.sequence_parallel:
         input_parallel = input_
-        model_parallel = mpu.get_tensor_model_parallel_world_size() > 1
+        model_parallel = core.get_tensor_model_parallel_world_size() > 1
         async_grad_allreduce = args.async_tensor_model_parallel_allreduce and \
             model_parallel and not args.sequence_parallel
     else:
-        input_parallel = mpu.copy_to_tensor_model_parallel_region(input_)
+        input_parallel = core.tensor_parallel.copy_to_tensor_model_parallel_region(input_)
         async_grad_allreduce = False

     # Matrix multiply.
-    logits_parallel = mpu.LinearWithGradAccumulationAndAsyncCommunication.apply(
-        input_parallel, word_embeddings_weight, bias,
-        args.gradient_accumulation_fusion,
-        async_grad_allreduce, args.sequence_parallel)
+    logits_parallel = core.tensor_parallel.linear_with_grad_accumulation_and_async_allreduce(
+        input=input_parallel,
+        weight=word_embeddings_weight,
+        bias=bias,
+        gradient_accumulation_fusion=args.gradient_accumulation_fusion,
+        async_grad_allreduce=async_grad_allreduce,
+        sequence_parallel_enabled=args.sequence_parallel)

     # Gather if needed.
     if parallel_output:
         return logits_parallel

-    return mpu.gather_from_tensor_model_parallel_region(logits_parallel)
+    return core.tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel)

 def get_language_model(num_tokentypes, add_pooler,
@@ -103,7 +106,7 @@ class Pooler(MegatronModule):
         # gather data along sequence dimensions
         # same pooler is run on all tensor parallel nodes
         if self.sequence_parallel:
-            hidden_states = mpu.gather_from_sequence_parallel_region(
+            hidden_states = core.tensor_parallel.gather_from_sequence_parallel_region(
                 hidden_states,
                 tensor_parallel_output_grad=False)
@@ -143,9 +146,13 @@ class Embedding(MegatronModule):
         args = get_args()

         # Word embeddings (parallel).
-        self.word_embeddings = mpu.VocabParallelEmbedding(
+        self.word_embeddings = core.tensor_parallel.VocabParallelEmbedding(
             vocab_size, self.hidden_size,
-            init_method=self.init_method)
+            init_method=self.init_method,
+            params_dtype=args.params_dtype,
+            use_cpu_initialization=args.use_cpu_initialization,
+            perform_initialization=args.perform_initialization
+        )
         self._word_embeddings_key = 'word_embeddings'

         # Position embedding (serial).
@@ -222,8 +229,8 @@ class Embedding(MegatronModule):
         # Dropout.
         if self.sequence_parallel:
-            embeddings = mpu.scatter_to_sequence_parallel_region(embeddings)
+            embeddings = core.tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
-            with mpu.get_cuda_rng_tracker().fork():
+            with core.tensor_parallel.get_cuda_rng_tracker().fork():
                 embeddings = self.embedding_dropout(embeddings)
         else:
             embeddings = self.embedding_dropout(embeddings)
......
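For reference, the matrix multiply in parallel_lm_logits now goes through the keyword-style wrapper instead of calling the autograd function positionally. A hedged sketch of the call shape only, with placeholder parameters standing in for the function's locals:

```python
from megatron.core import tensor_parallel

def lm_logits_sketch(input_parallel, weight, bias, args, async_grad_allreduce):
    # Keyword names follow the hunk above; real inputs are the
    # tensor-model-parallel shards prepared earlier in parallel_lm_logits.
    return tensor_parallel.linear_with_grad_accumulation_and_async_allreduce(
        input=input_parallel,
        weight=weight,
        bias=bias,
        gradient_accumulation_fusion=args.gradient_accumulation_fusion,
        async_grad_allreduce=async_grad_allreduce,
        sequence_parallel_enabled=args.sequence_parallel)
```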
@@ -8,6 +8,7 @@ from torch.nn.parameter import Parameter

 from megatron import get_args
 from megatron import mpu
+from megatron import core

 _FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
@@ -76,9 +77,12 @@ class MegatronModule(torch.nn.Module):
             self._word_embeddings_for_head_key = 'word_embeddings_for_head'
             # set word_embeddings weights to 0 here, then copy first
             # stage's weights using all_reduce below.
-            self.word_embeddings = mpu.VocabParallelEmbedding(
+            self.word_embeddings = core.tensor_parallel.VocabParallelEmbedding(
                 args.padded_vocab_size, args.hidden_size,
-                init_method=init_method_normal(args.init_method_std))
+                init_method=init_method_normal(args.init_method_std),
+                params_dtype=args.params_dtype,
+                use_cpu_initialization=args.use_cpu_initialization,
+                perform_initialization=args.perform_initialization)
             self.word_embeddings.weight.data.fill_(0)
             self.word_embeddings.weight.shared = True
......
@@ -6,8 +6,9 @@ from contextlib import nullcontext
 import torch
 import torch.nn.functional as F

-from megatron import get_timers, get_args, get_global_memory_buffer
-from megatron import mpu
+from megatron import get_timers, get_args
+from megatron.core import get_global_memory_buffer
+from megatron import core
 from .module import MegatronModule
 from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
 from megatron.model import LayerNorm
@@ -32,7 +33,7 @@ from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
 """

 class DropPath(MegatronModule):
     """Drop paths (Stochastic Depth) per sample
     (when applied in main path of residual blocks).
     """
@@ -52,6 +53,17 @@ class DropPath(MegatronModule):
         output = hidden_state.div(keep_prob) * random_tensor
         return output

+def _args_to_kwargs():
+    args = get_args()
+
+    common_kwargs = {
+        "params_dtype": args.params_dtype,
+        "use_cpu_initialization": args.use_cpu_initialization,
+        "perform_initialization": args.perform_initialization,
+        "gradient_accumulation_fusion": args.gradient_accumulation_fusion,
+        "sequence_parallel_enabled": args.sequence_parallel,
+    }
+    return common_kwargs

 class ParallelMLP(MegatronModule):
     """MLP.
@@ -65,13 +77,16 @@ class ParallelMLP(MegatronModule):
         super(ParallelMLP, self).__init__()
         args = get_args()

         # Project to 4h.
-        self.dense_h_to_4h = mpu.ColumnParallelLinear(
+        self.dense_h_to_4h = core.tensor_parallel.ColumnParallelLinear(
             args.hidden_size,
             args.ffn_hidden_size,
             gather_output=False,
             init_method=init_method,
-            skip_bias_add=True)
+            skip_bias_add=True,
+            async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
+            **_args_to_kwargs())

         self.bias_gelu_fusion = args.bias_gelu_fusion
         self.activation_func = F.gelu
@@ -81,12 +96,13 @@ class ParallelMLP(MegatronModule):
             self.activation_func = erf_gelu

         # Project back to h.
-        self.dense_4h_to_h = mpu.RowParallelLinear(
+        self.dense_4h_to_h = core.tensor_parallel.RowParallelLinear(
             args.ffn_hidden_size,
             args.hidden_size,
             input_is_parallel=True,
             init_method=output_layer_init_method,
-            skip_bias_add=True)
+            skip_bias_add=True,
+            **_args_to_kwargs())

     def forward(self, hidden_states):
@@ -136,7 +152,7 @@ class SwitchMLP(MegatronModule):
         output_total = torch.empty_like(hidden_states)
         output_bias_total = torch.empty_like(hidden_states)
         #TODO (rprenger) This does each expert in serial, but it could be parallelized
         for expert_num, expert in enumerate(self.experts):
             local_indices = (max_ind == expert_num).nonzero()
             hidden = hidden_states[local_indices,:]
@@ -173,12 +189,12 @@ class CoreAttention(MegatronModule):
         projection_size = args.kv_channels * args.num_attention_heads

         # Per attention head and per partition values.
-        world_size = mpu.get_tensor_model_parallel_world_size()
+        world_size = core.get_tensor_model_parallel_world_size()
-        self.hidden_size_per_partition = mpu.divide(projection_size,
+        self.hidden_size_per_partition = core.utils.divide(projection_size,
                                                     world_size)
-        self.hidden_size_per_attention_head = mpu.divide(
+        self.hidden_size_per_attention_head = core.utils.divide(
             projection_size, args.num_attention_heads)
-        self.num_attention_heads_per_partition = mpu.divide(
+        self.num_attention_heads_per_partition = core.utils.divide(
             args.num_attention_heads, world_size)

         coeff = None
@@ -247,7 +263,7 @@ class CoreAttention(MegatronModule):
         # seem a bit unusual, but is taken from the original Transformer paper.

         if not self.sequence_parallel:
-            with mpu.get_cuda_rng_tracker().fork():
+            with core.tensor_parallel.get_cuda_rng_tracker().fork():
                 attention_probs = self.attention_dropout(attention_probs)
         else:
             attention_probs = self.attention_dropout(attention_probs)
@@ -311,44 +327,52 @@ class ParallelAttention(MegatronModule):
         projection_size = args.kv_channels * args.num_attention_heads

         # Per attention head and per partition values.
-        world_size = mpu.get_tensor_model_parallel_world_size()
+        world_size = core.get_tensor_model_parallel_world_size()
-        self.hidden_size_per_attention_head = mpu.divide(
+        self.hidden_size_per_attention_head = core.utils.divide(
             projection_size, args.num_attention_heads)
-        self.num_attention_heads_per_partition = mpu.divide(
+        self.num_attention_heads_per_partition = core.utils.divide(
             args.num_attention_heads, world_size)

         # Strided linear layer.
         if attention_type == AttnType.self_attn:
-            self.query_key_value = mpu.ColumnParallelLinear(
+            self.query_key_value = core.tensor_parallel.ColumnParallelLinear(
                 args.hidden_size,
                 3 * projection_size,
                 gather_output=False,
-                init_method=init_method)
+                init_method=init_method,
+                async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
+                **_args_to_kwargs())
         else:
             assert attention_type == AttnType.cross_attn
-            self.query = mpu.ColumnParallelLinear(
+            self.query = core.tensor_parallel.ColumnParallelLinear(
                 args.hidden_size,
                 projection_size,
                 gather_output=False,
-                init_method=init_method)
+                init_method=init_method,
+                async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
+                **_args_to_kwargs())

-            self.key_value = mpu.ColumnParallelLinear(
+            self.key_value = core.tensor_parallel.ColumnParallelLinear(
                 args.hidden_size,
                 2 * projection_size,
                 gather_output=False,
-                init_method=init_method)
+                init_method=init_method,
+                async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
+                **_args_to_kwargs())

         self.core_attention = CoreAttention(self.layer_number,
                                             self.attn_mask_type)
         self.checkpoint_core_attention = args.recompute_granularity == 'selective'

         # Output.
-        self.dense = mpu.RowParallelLinear(
+        self.dense = core.tensor_parallel.RowParallelLinear(
             projection_size,
             args.hidden_size,
             input_is_parallel=True,
             init_method=output_layer_init_method,
-            skip_bias_add=True)
+            skip_bias_add=True,
+            **_args_to_kwargs())

     def _checkpointed_attention_forward(self, query_layer, key_layer,
                                         value_layer, attention_mask):
@@ -362,7 +386,7 @@ class ParallelAttention(MegatronModule):
                 value_layer, attention_mask)
             return output_

-        hidden_states = mpu.checkpoint(
+        hidden_states = core.tensor_parallel.checkpoint(
             custom_forward,
             False, query_layer, key_layer, value_layer, attention_mask)
@@ -415,7 +439,7 @@ class ParallelAttention(MegatronModule):
             # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
             (query_layer,
              key_layer,
-             value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
+             value_layer) = core.tensor_parallel.split_tensor_along_last_dim(mixed_x_layer, 3)
         else:
             # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
             mixed_kv_layer, _ = self.key_value(encoder_output)
@@ -428,7 +452,7 @@ class ParallelAttention(MegatronModule):
             # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
             (key_layer,
-             value_layer) = mpu.split_tensor_along_last_dim(mixed_kv_layer, 2)
+             value_layer) = core.tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)

             # Attention head [sq, b, h] --> [sq, b, hp]
             query_layer, _ = self.query(hidden_states)
@@ -674,9 +698,9 @@ class ParallelTransformerLayer(MegatronModule):
             # won't result in memory savings (like the data loader, or
             # p2p_communication), it serves to document the origin of this
             # 'view' tensor.
-            output = mpu.make_viewless_tensor(inp = output,
+            output = core.utils.make_viewless_tensor(inp = output,
                                               requires_grad = output.requires_grad,
                                               keep_graph = True)
         else:
             out = torch.nn.functional.dropout(mlp_output + mlp_bias,
@@ -719,7 +743,7 @@ class ParallelTransformer(MegatronModule):
     def __init__(self, init_method, output_layer_init_method,
                  layer_type=LayerType.encoder,
                  self_attn_mask_type=AttnMaskType.padding,
                  post_layer_norm=True,
                  pre_process=True, post_process=True,
                  drop_path_rate=0.0):
         super(ParallelTransformer, self).__init__()
@@ -745,7 +769,7 @@ class ParallelTransformer(MegatronModule):
         self.sequence_parallel = args.sequence_parallel

         # Number of layers.
-        self.num_layers = mpu.get_num_layers(
+        self.num_layers = core.get_num_layers(
             args, args.model_type == ModelType.encoder_and_decoder)

         self.drop_path_rates = [rate.item() for rate in torch.linspace(0, self.drop_path_rate, args.num_layers)]
@@ -775,21 +799,21 @@ class ParallelTransformer(MegatronModule):
             # layers to stages like (each list is a model chunk):
             # Stage 0: [0, 1] [4, 5]
             # Stage 1: [2, 3] [6, 7]
-            offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
+            offset = core.get_virtual_pipeline_model_parallel_rank() * (
                 args.num_layers // args.virtual_pipeline_model_parallel_size) + \
-                (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
+                (core.get_pipeline_model_parallel_rank() * self.num_layers)
         else:
             # Each stage gets a contiguous set of layers.
             if args.model_type == ModelType.encoder_and_decoder and \
-                    mpu.get_pipeline_model_parallel_world_size() > 1:
+                    core.get_pipeline_model_parallel_world_size() > 1:
-                pipeline_rank = mpu.get_pipeline_model_parallel_rank()
+                pipeline_rank = core.get_pipeline_model_parallel_rank()
                 if layer_type == LayerType.encoder:
                     offset = pipeline_rank * self.num_layers
                 else:
                     num_ranks_in_enc = args.pipeline_model_parallel_split_rank
                     offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
             else:
-                offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
+                offset = core.get_pipeline_model_parallel_rank() * self.num_layers

         if self.num_layers == 0:
             # When a standalone embedding stage is used (e.g.,
@@ -838,7 +862,7 @@ class ParallelTransformer(MegatronModule):
             # A method to further reduce memory usage reducing checkpoints.
             l = 0
             while l < self.num_layers:
-                hidden_states = mpu.checkpoint(
+                hidden_states = core.tensor_parallel.checkpoint(
                     custom(l, l + self.recompute_num_layers),
                     self.distribute_saved_activations,
                     hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
@@ -850,7 +874,7 @@ class ParallelTransformer(MegatronModule):
             # A method fully use the device memory removing redundant re-computation.
             for l in range(self.num_layers):
                 if l < self.recompute_num_layers:
-                    hidden_states = mpu.checkpoint(
+                    hidden_states = core.tensor_parallel.checkpoint(
                         custom(l, l + 1),
                         self.distribute_saved_activations,
                         hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
@@ -896,19 +920,19 @@ class ParallelTransformer(MegatronModule):
        #   However, we don't explicitly check mbs == 1 here because
        #   make_viewless_tensor() has negligible overhead when its input
        #   is already viewless.
        #
        # - For the 'else' case above, calling make_viewless_tensor() here is
        #   likely redundant, since p2p_communication.py (likely originator)
        #   already creates viewless tensors. That said, make_viewless_tensor()
        #   is called here to be future-proof and corner-case-proof.
-        hidden_states = mpu.make_viewless_tensor(
+        hidden_states = core.utils.make_viewless_tensor(
             hidden_states,
             requires_grad=True,
             keep_graph=True,
         )

         if self.sequence_parallel:
-            rng_context = mpu.get_cuda_rng_tracker().fork()
+            rng_context = core.tensor_parallel.get_cuda_rng_tracker().fork()
         else:
             rng_context = nullcontext()
......
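The recompute paths above now call the relocated activation-checkpoint helper. A hedged sketch of its call convention as used in this file, checkpoint(run_fn, distribute_saved_activations, *inputs); running it for real requires CUDA and the model-parallel RNG tracker set up by model_parallel_cuda_manual_seed.

```python
from megatron.core import tensor_parallel

def run_block(hidden_states, attention_mask):
    # Stand-in for the custom(l, l + n) closure built in ParallelTransformer;
    # the function is re-executed during the backward pass to rebuild activations.
    return hidden_states * 2.0

def recompute_forward(hidden_states, attention_mask):
    return tensor_parallel.checkpoint(
        run_block,
        False,                      # distribute_saved_activations
        hidden_states, attention_mask)
```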
@@ -32,30 +32,6 @@ from .initialize import get_virtual_pipeline_model_parallel_rank, set_virtual_pi
 from .initialize import initialize_model_parallel
 from .initialize import model_parallel_is_initialized

-from .layers import LinearWithGradAccumulationAndAsyncCommunication
-from .layers import ColumnParallelLinear
-from .layers import RowParallelLinear
-from .layers import VocabParallelEmbedding
-from .layers import (set_tensor_model_parallel_attributes,
-                     set_defaults_if_not_set_tensor_model_parallel_attributes,
-                     copy_tensor_model_parallel_attributes)
-
-from .mappings import copy_to_tensor_model_parallel_region
-from .mappings import reduce_from_tensor_model_parallel_region
-from .mappings import scatter_to_tensor_model_parallel_region
-from .mappings import gather_from_tensor_model_parallel_region
-from .mappings import scatter_to_sequence_parallel_region
-from .mappings import gather_from_sequence_parallel_region
-from .mappings import reduce_scatter_to_sequence_parallel_region
-
-from .random import checkpoint
-from .random import get_cuda_rng_tracker
-from .random import model_parallel_cuda_manual_seed
-from .random import gather_split_1d_tensor
-from .random import split_tensor_into_1d_equal_chunks
-from .random import make_viewless_tensor
-from .random import assert_viewless_tensor
-from .random import safely_set_viewless_tensor_data
-
 from .utils import divide
 from .utils import split_tensor_along_last_dim
@@ -9,7 +9,7 @@ from apex.multi_tensor_apply import multi_tensor_applier
 import amp_C

 from megatron.model.module import param_is_not_shared
-from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
+from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate

 def clip_grad_norm_fp32(parameters, grads_for_norm,
......
@@ -11,7 +11,7 @@ from megatron import get_timers
 from megatron import mpu
 from megatron import print_rank_0
 from megatron.model.module import param_is_not_shared
-from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
+from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate

 from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper
......
@@ -12,11 +12,12 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

 from megatron import get_timers
 from megatron import mpu
+from megatron import core
 from megatron import print_rank_0
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import Float16Module
 from megatron.model.module import param_is_not_shared
-from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
+from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate
 from megatron.utils import unwrap_model

 from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32
@@ -527,8 +528,8 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
                     # Create a copy
                     main_param = param.detach().clone().float()
                     # Copy tensor model parallel attributes.
-                    mpu.copy_tensor_model_parallel_attributes(main_param,
+                    core.tensor_parallel.copy_tensor_model_parallel_attributes(main_param,
                                                               param)
                     if hasattr(param, 'shared'):
                         main_param.shared = param.shared
                     # Replace the optimizer params with the new fp32 copy.
......