Commit 5942af97 authored by Jared Casper

Alias core.parallel_state as mpu and use it throughout code. RIP mpu.

parent c2ea914f
@@ -9,8 +9,8 @@ import numpy as np
 import torch
-from megatron import (core,
-                      update_num_microbatches)
+from megatron import update_num_microbatches
+from megatron.core import mpu, tensor_parallel
 from .global_vars import get_args
 from .utils import (unwrap_model,
                     print_rank_0)
@@ -79,11 +79,11 @@ def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
     # Use both the tensor and pipeline MP rank.
     if pipeline_parallel is None:
-        pipeline_parallel = (core.get_pipeline_model_parallel_world_size() > 1)
+        pipeline_parallel = (mpu.get_pipeline_model_parallel_world_size() > 1)
     if tensor_rank is None:
-        tensor_rank = core.get_tensor_model_parallel_rank()
+        tensor_rank = mpu.get_tensor_model_parallel_rank()
     if pipeline_rank is None:
-        pipeline_rank = core.get_pipeline_model_parallel_rank()
+        pipeline_rank = mpu.get_pipeline_model_parallel_rank()

     # Use both the tensor and pipeline MP rank. If using the distributed
     # optimizer, then the optimizer's path must additionally include the
@@ -98,7 +98,7 @@ def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
     if use_distributed_optimizer:
         model_name = os.path.join(common_path, "model_rng.pt")
         optim_name = os.path.join(
-            common_path + "_%03d" % core.get_data_parallel_rank(),
+            common_path + "_%03d" % mpu.get_data_parallel_rank(),
             "optim.pt")
     else:
         model_name = optim_name = os.path.join(common_path, "model_optim_rng.pt")
@@ -185,18 +185,18 @@ def get_rng_state():
        'np_rng_state': np.random.get_state(),
        'torch_rng_state': torch.get_rng_state(),
        'cuda_rng_state': torch.cuda.get_rng_state(),
-       'rng_tracker_states': core.tensor_parallel.get_cuda_rng_tracker().get_states()}
+       'rng_tracker_states': tensor_parallel.get_cuda_rng_tracker().get_states()}

    rng_state_list = None
    if torch.distributed.is_initialized() and \
-           core.get_data_parallel_world_size() > 1 and \
+           mpu.get_data_parallel_world_size() > 1 and \
           args.data_parallel_random_init:
        rng_state_list = \
-           [None for i in range(core.get_data_parallel_world_size())]
+           [None for i in range(mpu.get_data_parallel_world_size())]
        torch.distributed.all_gather_object(
            rng_state_list,
            rng_state,
-           group=core.get_data_parallel_group())
+           group=mpu.get_data_parallel_group())
    else:
        rng_state_list = [rng_state]
@@ -223,7 +223,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
    # Collect args, model, RNG.
    model_state_dict = {}
    if not torch.distributed.is_initialized() \
-           or core.get_data_parallel_rank() == 0:
+           or mpu.get_data_parallel_rank() == 0:

        # Arguments, iteration, and model.
        model_state_dict['args'] = args
@@ -233,7 +233,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
            model_state_dict['model'] = model[0].state_dict_for_save_checkpoint()
        else:
            for i in range(len(model)):
-               core.set_virtual_pipeline_model_parallel_rank(i)
+               mpu.set_virtual_pipeline_model_parallel_rank(i)
                model_state_dict['model%d' % i] = \
                    model[i].state_dict_for_save_checkpoint()
@@ -246,7 +246,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
    optim_state_dict = {}
    if not args.no_save_optim \
       and (not torch.distributed.is_initialized()
-           or core.get_data_parallel_rank() == 0
+           or mpu.get_data_parallel_rank() == 0
            or args.use_distributed_optimizer):

        # Optimizer stuff.
@@ -548,7 +548,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
        model[0].load_state_dict(model_state_dict['model'], strict=strict)
    else:
        for i in range(len(model)):
-           core.set_virtual_pipeline_model_parallel_rank(i)
+           mpu.set_virtual_pipeline_model_parallel_rank(i)
            model[i].load_state_dict(model_state_dict['model%d' % i], strict=strict)

    # Fix up query/key/value matrix ordering if needed
@@ -580,7 +580,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
            # access rng_state for data parallel rank
            if args.data_parallel_random_init:
-               rng_state = model_state_dict['rng_state'][core.get_data_parallel_rank()]
+               rng_state = model_state_dict['rng_state'][mpu.get_data_parallel_rank()]
            else:
                rng_state = model_state_dict['rng_state'][0]
            random.setstate(rng_state['random_rng_state'])
@@ -590,7 +590,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
            # Check for empty states array
            if not rng_state['rng_tracker_states']:
                raise KeyError
-           core.tensor_parallel.get_cuda_rng_tracker().set_states(
+           tensor_parallel.get_cuda_rng_tracker().set_states(
                rng_state['rng_tracker_states'])
        else:  # backward compatability
            random.setstate(model_state_dict['random_rng_state'])
@@ -600,7 +600,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
            # Check for empty states array
            if not model_state_dict['rng_tracker_states']:
                raise KeyError
-           core.tensor_parallel.get_cuda_rng_tracker().set_states(
+           tensor_parallel.get_cuda_rng_tracker().set_states(
                model_state_dict['rng_tracker_states'])
    except KeyError:
        print_rank_0('Unable to load rng state from checkpoint {}. '
@@ -640,7 +640,7 @@ def load_biencoder_checkpoint(model, only_query_model=False,
                                        args.use_distributed_optimizer,
                                        release=False)

-   if core.get_data_parallel_rank() == 0:
+   if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading checkpoint {}'.format(
            torch.distributed.get_rank(), checkpoint_name))
@@ -656,7 +656,7 @@ def load_biencoder_checkpoint(model, only_query_model=False,
    model[0].load_state_dict(ret_state_dict)
    torch.distributed.barrier()

-   if core.get_data_parallel_rank() == 0:
+   if mpu.get_data_parallel_rank() == 0:
        print(' successfully loaded {}'.format(checkpoint_name))

    return model
-from .parallel_state import (
-    initialize_model_parallel,
-    get_tensor_model_parallel_world_size,
-    get_tensor_model_parallel_rank,
-    get_pipeline_model_parallel_world_size,
-    get_pipeline_model_parallel_rank,
-    get_virtual_pipeline_model_parallel_rank,
-    set_virtual_pipeline_model_parallel_rank,
-    get_data_parallel_world_size,
-    get_data_parallel_rank,
-    get_global_memory_buffer,
-    get_num_layers,
-)
-from megatron.core import tensor_parallel
+import megatron.core.parallel_state
+import megatron.core.tensor_parallel
+import megatron.core.utils
+
+# Alias parallel_state as mpu, its legacy name
+mpu = parallel_state
+
+__all__ = [
+    "parallel_state",
+    "tensor_parallel",
+    "utils",
+]
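The rewritten megatron.core package interface above is what makes the rest of this diff mechanical: `mpu` is now just another name for `megatron.core.parallel_state`. A minimal usage sketch of the new import surface (the calls in the comments are taken from later hunks, nothing here is new API):

    # Sketch only: `mpu` and `parallel_state` are the same module object.
    from megatron.core import mpu, parallel_state, tensor_parallel

    assert mpu is parallel_state

    # Typical call sites after this commit (see the checkpointing and
    # initialize hunks in this diff):
    #   mpu.get_data_parallel_rank()
    #   mpu.get_tensor_model_parallel_rank()
    #   tensor_parallel.get_cuda_rng_tracker()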
@@ -47,10 +47,6 @@ _DATA_PARALLEL_GLOBAL_RANKS = None
 # Memory buffers to avoid dynamic memory allocation
 _GLOBAL_MEMORY_BUFFER = None

-def is_unitialized():
-    """Useful for code segments that may be accessed with or without mpu initialization"""
-    return _DATA_PARALLEL_GROUP is None

 def initialize_model_parallel(
     tensor_model_parallel_size: int = 1,
...
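With is_unitialized() removed, call sites in this commit (see the OpenRetreivalDataStore and FaissMIPSIndex hunks below) switch to negating mpu.model_parallel_is_initialized(). A hedged sketch of the replacement guard, assuming only that the mpu alias from megatron.core is importable:

    from megatron.core import mpu

    # Old guard:  if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
    # New guard used throughout this diff; short-circuiting keeps the rank
    # query from running before model parallel is initialized.
    if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
        print("> running on data-parallel rank 0 (or before initialization)", flush=True)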
@@ -5,6 +5,7 @@ from .layers import (
     ColumnParallelLinear,
     RowParallelLinear,
     VocabParallelEmbedding,
+    set_tensor_model_parallel_attributes,
     set_defaults_if_not_set_tensor_model_parallel_attributes,
     copy_tensor_model_parallel_attributes,
     param_is_not_tensor_parallel_duplicate,
@@ -23,10 +24,14 @@ from .mappings import (
 from .random import (
     checkpoint,
     get_cuda_rng_tracker,
-    model_parallel_cuda_manual_seed
+    model_parallel_cuda_manual_seed,
 )
-from .utils import split_tensor_along_last_dim
+from .utils import (
+    split_tensor_along_last_dim,
+    split_tensor_into_1d_equal_chunks,
+    gather_split_1d_tensor,
+)

 __all__ = [
     # cross_entropy.py
...
@@ -17,6 +17,10 @@ from megatron.core.parallel_state import (
     get_tensor_model_parallel_world_size,
 )

+from .utils import (
+    split_tensor_into_1d_equal_chunks,
+    gather_split_1d_tensor,
+)

 # Default name for the model parallel rng tracker.
 _MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
@@ -55,38 +59,6 @@ def _set_cuda_rng_state(new_state, device=-1):
         _lazy_call(cb)

-def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
-    """Break a tensor into equal 1D chunks."""
-    partition_size = torch.numel(tensor) // \
-        get_tensor_model_parallel_world_size()
-    start_index = partition_size * get_tensor_model_parallel_rank()
-    end_index = start_index + partition_size
-    if new_buffer:
-        data = torch.empty(partition_size, dtype=tensor.dtype,
-                           device=torch.cuda.current_device(),
-                           requires_grad=False)
-        data.copy_(tensor.view(-1)[start_index:end_index])
-    else:
-        data = tensor.view(-1)[start_index:end_index]
-    return data
-
-
-def gather_split_1d_tensor(tensor):
-    """Opposite of above function, gather values from model parallel ranks."""
-    numel_gathered = torch.numel(tensor) * \
-        get_tensor_model_parallel_world_size()
-    gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
-                           device=torch.cuda.current_device(),
-                           requires_grad=False)
-    # TODO: This API is experimental in pytorch (as of Feb 2022) and
-    # this might break in future pytorch releases. We chose this API
-    # as opposed to torch.distributed.all_gather for efficiency reasons.
-    # This API calls directly NCCL all-gather versus the former does
-    # internal copies and can potentially cause slow down.
-    torch.distributed._all_gather_base(gathered, tensor,
-                                       group=get_tensor_model_parallel_group())
-    return gathered

 class CudaRNGStatesTracker:
     """Tracker for the cuda RNG states.
...
@@ -10,12 +10,16 @@ def split_tensor_along_last_dim(
     num_partitions: int,
     contiguous_split_chunks: bool = False,
 ) -> List[torch.Tensor]:
-    """Split a tensor along its last dimension.
-    Arguments:
-        tensor: input tensor.
-        num_partitions: number of partitions to split the tensor
-        contiguous_split_chunks: If True, make each chunk contiguous
-                                 in memory.
+    """ Split a tensor along its last dimension.
+
+        Arguments:
+            tensor: input tensor.
+            num_partitions: number of partitions to split the tensor
+            contiguous_split_chunks: If True, make each chunk contiguous
+                                     in memory.
+
+        Returns:
+            A list of Tensors
     """
     # Get the size and dimension.
     last_dim = tensor.dim() - 1
@@ -28,11 +32,64 @@ def split_tensor_along_last_dim(
     return tensor_list

+def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
+    """ Break a tensor into equal 1D chunks across tensor parallel ranks.
+
+        Returns a Tensor or View with this rank's portion of the data.
+
+        Arguments:
+            tensor: The tensor to split
+
+        Keyword Arguments:
+            new_buffer (bool): If True, returns a new Tensor.
+                               If False, returns a view into the existing Tensor.
+                               Default is False
+    """
+    partition_size = torch.numel(tensor) // \
+        get_tensor_model_parallel_world_size()
+    start_index = partition_size * get_tensor_model_parallel_rank()
+    end_index = start_index + partition_size
+    if new_buffer:
+        data = torch.empty(partition_size, dtype=tensor.dtype,
+                           device=torch.cuda.current_device(),
+                           requires_grad=False)
+        data.copy_(tensor.view(-1)[start_index:end_index])
+    else:
+        data = tensor.view(-1)[start_index:end_index]
+    return data
+
+
+def gather_split_1d_tensor(tensor):
+    """ Opposite of split_tensor_into_1d_equal_chunks. Gather values from tensor
+        model parallel ranks.
+
+        Returns a new Tensor with the gathered data.
+
+        Arguments:
+            tensor: A Tensor or view of this rank's portion of the data.
+    """
+    numel_gathered = torch.numel(tensor) * \
+        get_tensor_model_parallel_world_size()
+    gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
+                           device=torch.cuda.current_device(),
+                           requires_grad=False)
+    # TODO: This API is experimental in pytorch (as of Feb 2022) and
+    # this might break in future pytorch releases. We chose this API
+    # as opposed to torch.distributed.all_gather for efficiency reasons.
+    # This API calls directly NCCL all-gather versus the former does
+    # internal copies and can potentially cause slow down.
+    torch.distributed._all_gather_base(gathered, tensor,
+                                       group=get_tensor_model_parallel_group())
+    return gathered

 class VocabUtility:
-    """Split the vocabulary into `world_size` chunks and return the
-    first and last index of the vocabulary belonging to the `rank`
-    partition: Note that indices in [fist, last)"""
+    """ Split the vocabulary into `world_size` chunks and return the first
+        and last index of the vocabulary belonging to the `rank`
+        partition: Note that indices in [fist, last)
+    """

     @staticmethod
     def vocab_range_from_per_partition_vocab_size(
...
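With split_tensor_into_1d_equal_chunks and gather_split_1d_tensor now exported from megatron.core.tensor_parallel (see the package __init__ hunk above), a round-trip looks like the sketch below. It assumes torch.distributed and mpu.initialize_model_parallel(...) have already been set up and that `full` holds identical values on every tensor-parallel rank; the variable names are illustrative.

    import torch
    from megatron.core import tensor_parallel

    # Identical on every tensor-parallel rank (e.g. produced from a shared seed).
    full = torch.arange(1024, dtype=torch.float32,
                        device=torch.cuda.current_device())

    # Each rank keeps an equal 1D slice; new_buffer=True copies instead of viewing.
    shard = tensor_parallel.split_tensor_into_1d_equal_chunks(full, new_buffer=True)

    # All-gather across the tensor-parallel group reassembles the flat tensor.
    restored = tensor_parallel.gather_split_1d_tensor(shard)
    assert torch.equal(restored, full.view(-1))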
@@ -21,35 +21,6 @@ def divide(numerator, denominator):
     return numerator // denominator

-def split_tensor_into_1d_equal_chunks(tensor):
-    """Break a tensor into equal 1D chunks."""
-    data = tensor.view(-1)
-    partition_size = (
-        torch.numel(data) // parallel_state.get_tensor_model_parallel_world_size()
-    )
-    start_index = partition_size * parallel_state.get_tensor_model_parallel_rank()
-    end_index = start_index + partition_size
-    return data[start_index:end_index]
-
-
-def gather_split_1d_tensor(tensor):
-    """Opposite of above function, gather values from model parallel ranks."""
-    world_size = parallel_state.get_tensor_model_parallel_world_size()
-    numel = torch.numel(tensor)
-    numel_gathered = world_size * numel
-    gathered = torch.empty(
-        numel_gathered,
-        dtype=tensor.dtype,
-        device=torch.cuda.current_device(),
-        requires_grad=False,
-    )
-    torch.distributed._all_gather_base(
-        gathered,
-        tensor,
-        group=parallel_state.get_tensor_model_parallel_group()
-    )
-    return gathered

 class GlobalMemoryBuffer:
     """Global buffer to avoid dynamic memory allocations.
     Caller should ensure that buffers of the same name
...
@@ -4,7 +4,8 @@ import time
 import numpy as np
 import torch

-from megatron import get_args, get_tokenizer, mpu, print_rank_0
+from megatron import get_args, get_tokenizer, print_rank_0
+from megatron.core import mpu, tensor_parallel
 from megatron.data.dataset_utils import create_masked_lm_predictions, \
     pad_and_convert_to_numpy
 from megatron.data.data_samplers import MegatronPretrainingSampler
@@ -57,7 +58,7 @@ def get_ict_batch(data_iterator):
         data = None
     else:
         data = next(data_iterator)
-    data_b = mpu.broadcast_data(keys, data, datatype)
+    data_b = tensor_parallel.broadcast_data(keys, data, datatype)

     # Unpack.
     query_tokens = data_b['query_tokens'].long()
...
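The broadcast_data calls that move from mpu to tensor_parallel in this and the following data hunks keep the same (keys, data, datatype) contract: tensor-parallel rank 0 provides the batch and the other ranks in its group receive it. A minimal sketch with illustrative key names and shapes:

    import torch
    from megatron.core import mpu, tensor_parallel

    keys = ['query_tokens', 'query_mask']   # illustrative keys
    datatype = torch.int64

    # Only tensor-parallel rank 0 needs the real batch; other ranks pass None.
    if mpu.get_tensor_model_parallel_rank() == 0:
        data = {'query_tokens': torch.zeros(8, 64, dtype=datatype),
                'query_mask': torch.ones(8, 64, dtype=datatype)}
    else:
        data = None

    data_b = tensor_parallel.broadcast_data(keys, data, datatype)
    query_tokens = data_b['query_tokens'].long()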
@@ -8,8 +8,6 @@ import numpy as np
 import torch

 from megatron import print_rank_0
-from megatron import mpu

 class BlendableDataset(torch.utils.data.Dataset):
...
@@ -8,7 +8,7 @@ import torch
 import numpy as np
 from torch.utils.data import Dataset
 from megatron import get_args
-from megatron import mpu
+from megatron.core import mpu

 def build_pretraining_data_loader(dataset, consumed_samples):
...
@@ -28,9 +28,9 @@ import torch
 from megatron import (
     get_args,
-    mpu,
     print_rank_0
 )
+from megatron.core import mpu
 from megatron.data.blendable_dataset import BlendableDataset
 from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
...
@@ -8,7 +8,8 @@ import time
 import numpy as np
 import torch

-from megatron import mpu, print_rank_0
+from megatron import print_rank_0
+from megatron.core import mpu
 from megatron.data.blendable_dataset import BlendableDataset
 from megatron.data.dataset_utils import get_datasets_weights_and_num_samples
 from megatron.data.dataset_utils import get_train_valid_test_split_
...
@@ -9,7 +9,8 @@ import random
 import torch
 from torch.utils.data import Dataset

-from megatron import print_rank_0, get_args, get_tokenizer, mpu
+from megatron import print_rank_0, get_args, get_tokenizer
+from megatron.core import tensor_parallel
 from megatron.data.biencoder_dataset_utils import make_attention_mask

 def get_open_retrieval_wiki_dataset():
@@ -32,7 +33,7 @@ def get_open_retrieval_batch(data_iterator):
     # Broadcast data.
     data = None if data_iterator is None else next(data_iterator)
-    data_b = mpu.broadcast_data(keys, data, datatype)
+    data_b = tensor_parallel.broadcast_data(keys, data, datatype)

     # Unpack.
     row_id = data_b['row_id'].long()
...
@@ -4,9 +4,10 @@ import time
 import numpy as np
 import torch

-from megatron import mpu, print_rank_0
+from megatron import print_rank_0
+from megatron.core import mpu, tensor_parallel
 from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy
-from megatron import get_args, get_tokenizer, print_rank_0, mpu
+from megatron import get_args, get_tokenizer, print_rank_0

 def get_one_epoch_dataloader(dataset, micro_batch_size=None):
@@ -47,7 +48,7 @@ def get_ict_batch(data_iterator):
         data = None
     else:
         data = next(data_iterator)
-    data_b = mpu.broadcast_data(keys, data, datatype)
+    data_b = tensor_parallel.broadcast_data(keys, data, datatype)

     # Unpack.
     query_tokens = data_b['query_tokens'].long()
...
@@ -7,7 +7,7 @@ import numpy as np
 import torch

 from megatron import get_args
-from megatron import mpu
+from megatron.core import mpu

 def detach(tensor):
@@ -50,10 +50,10 @@ class OpenRetreivalDataStore(object):
     def load_from_file(self):
         """Populate members from instance saved to file"""

-        if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
+        if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
             print("\n> Unpickling BlockData", flush=True)
         state_dict = pickle.load(open(self.embedding_path, 'rb'))
-        if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
+        if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
             print(">> Finished unpickling BlockData\n", flush=True)

         self.embed_data = state_dict['embed_data']
@@ -137,7 +137,7 @@ class FaissMIPSIndex(object):
         except ImportError:
             raise Exception("Error: Please install faiss to use FaissMIPSIndex")

-        if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
+        if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
             print("\n> Building index", flush=True)

         cpu_index = faiss.IndexFlatIP(self.embed_size)
@@ -149,12 +149,12 @@ class FaissMIPSIndex(object):
             config.useFloat16 = True
             gpu_index = faiss.index_cpu_to_all_gpus(cpu_index, co=config)
             self.mips_index = faiss.IndexIDMap(gpu_index)
-            if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
+            if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
                 print(">> Initialized index on GPU", flush=True)
         else:
             # CPU index supports IDs so wrap with IDMap
             self.mips_index = faiss.IndexIDMap(cpu_index)
-            if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
+            if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
                 print(">> Initialized index on CPU", flush=True)

         # if we were constructed with a BlockData, then automatically load it
@@ -199,7 +199,7 @@ class FaissMIPSIndex(object):
         self.mips_index.add_with_ids(embeds_arr, indices_arr)

-        if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
+        if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
             print(">>> Finished adding block data to index", flush=True)

     def search_mips_index(self, query_embeds, top_k, reconstruct=True):
...
@@ -4,7 +4,7 @@ import torch
 import torch.distributed as dist

 from megatron import get_args, print_rank_0
-from megatron import mpu
+from megatron.core import mpu
 from megatron.checkpointing import load_biencoder_checkpoint
 from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
 from megatron.data.orqa_wiki_dataset import get_open_retrieval_batch
...
@@ -14,13 +14,10 @@ from megatron import fused_kernels
 from megatron import get_adlr_autoresume
 from megatron import get_args
 from megatron import get_tensorboard_writer
-from megatron import mpu
-from megatron import core
+from megatron.core import mpu, tensor_parallel
 from megatron.arguments import (parse_args, validate_args)
 from megatron.checkpointing import load_args_from_checkpoint
 from megatron.global_vars import set_global_variables
-from megatron.mpu import (set_tensor_model_parallel_rank,
-                          set_tensor_model_parallel_world_size)
 from megatron.model.transformer import bias_dropout_add_fused_train
 from megatron.model.fused_bias_gelu import bias_gelu
@@ -65,13 +62,14 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
         args = get_args()
         if args.lazy_mpu_init:
+            # TODO is this still a necessary option?
             args.use_cpu_initialization=True
             # delayed initialization of DDP-related stuff
             # We only set basic DDP globals
-            set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
+            mpu.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
             # and return function for external DDP manager
             # to call when it has DDP initialized
-            set_tensor_model_parallel_rank(args.rank)
+            mpu.set_tensor_model_parallel_rank(args.rank)
             return finish_mpu_init
         else:
             # Megatron's MPU is the master. Complete initialization right away.
@@ -147,7 +145,7 @@ def _compile_dependencies():

 def _initialize_distributed():
-    """Initialize torch.distributed and mpu."""
+    """Initialize torch.distributed and core model parallel."""
     args = get_args()

     device_count = torch.cuda.device_count()
@@ -185,17 +183,14 @@ def _initialize_distributed():
             print('model parallel is already initialized')
         else:
             mpu.initialize_model_parallel(args.tensor_model_parallel_size,
-                                          args.pipeline_model_parallel_size,
-                                          args.virtual_pipeline_model_parallel_size,
-                                          args.pipeline_model_parallel_split_rank)
-            core.initialize_model_parallel(args.tensor_model_parallel_size,
                                           args.pipeline_model_parallel_size,
                                           args.virtual_pipeline_model_parallel_size,
                                           args.pipeline_model_parallel_split_rank)
-            print(f'> initialized tensor model parallel with size '
-                  f'{core.get_tensor_model_parallel_world_size()}')
-            print(f'> initialized pipeline model parallel with size '
-                  f'{core.get_pipeline_model_parallel_world_size()}')
+            if args.rank == 0:
+                print(f'> initialized tensor model parallel with size '
+                      f'{mpu.get_tensor_model_parallel_world_size()}')
+                print(f'> initialized pipeline model parallel with size '
+                      f'{mpu.get_pipeline_model_parallel_world_size()}')

 def _init_autoresume():
@@ -219,7 +214,7 @@ def _set_random_seed(seed_, data_parallel_random_init=False):
         np.random.seed(seed)
         torch.manual_seed(seed)
         if torch.cuda.device_count() > 0:
-            core.tensor_parallel.model_parallel_cuda_manual_seed(seed)
+            tensor_parallel.model_parallel_cuda_manual_seed(seed)
     else:
         raise ValueError('Seed ({}) should be a positive integer.'.format(seed))
...
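After this hunk, _initialize_distributed() goes through a single initialization path: one mpu.initialize_model_parallel(...) call, with the banner printed from the new mpu accessors on rank 0 only. A hedged standalone sketch of the same sequence (the parallel sizes and seed are illustrative; it assumes a torchrun-style environment so init_process_group can read its env:// settings):

    import torch
    from megatron.core import mpu, tensor_parallel

    torch.distributed.init_process_group(backend='nccl')

    if not mpu.model_parallel_is_initialized():
        mpu.initialize_model_parallel(tensor_model_parallel_size=2,
                                      pipeline_model_parallel_size=1)

    if torch.distributed.get_rank() == 0:
        print(f'> initialized tensor model parallel with size '
              f'{mpu.get_tensor_model_parallel_world_size()}')

    # Mirrors _set_random_seed(): seed the per-rank CUDA RNG tracker.
    tensor_parallel.model_parallel_cuda_manual_seed(1234)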
@@ -5,7 +5,7 @@
 import torch

 from megatron import get_args
-from megatron import core
+from megatron.core import tensor_parallel
 from megatron.model.enums import AttnMaskType
 from megatron.model.language_model import parallel_lm_logits
 from megatron.model.language_model import get_language_model
@@ -61,7 +61,7 @@ class BertLMHead(MegatronModule):
         args = get_args()

         self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
-        mpu.set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
+        tensor_parallel.set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
         self.parallel_output = parallel_output

         self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
@@ -110,9 +110,9 @@ def post_language_model_processing(lm_output, pooled_output,
         # lm_logits : [s, b, h] and lm_labels: [s, b]
         if fp16_lm_cross_entropy:
             assert lm_logits.dtype == torch.half
-            lm_loss = core.vocab_parallel_cross_entropy(lm_logits, lm_labels)
+            lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits, lm_labels)
         else:
-            lm_loss = core.vocab_parallel_cross_entropy(lm_logits.float(),
+            lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits.float(),
                                                         lm_labels)
         # [s, b] => [b s]
         lm_loss = lm_loss.transpose(0,1).contiguous()
...
@@ -2,11 +2,11 @@ import os
 import torch
 import sys

-from megatron import get_args, print_rank_0
+from megatron import get_args, print_rank_0, get_tokenizer
+from megatron.core import mpu
 from megatron.checkpointing import fix_query_key_value_ordering
 from megatron.checkpointing import get_checkpoint_tracker_filename
 from megatron.checkpointing import get_checkpoint_name
-from megatron import mpu, get_tokenizer
 from megatron.model.bert_model import bert_position_ids
 from megatron.model.enums import AttnMaskType
 from megatron.model.language_model import get_language_model
...
@@ -5,7 +5,6 @@
 import torch

 from megatron import get_args, print_rank_last
-from megatron import mpu
 from megatron.model.enums import AttnMaskType
 from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
 from megatron.model.language_model import get_language_model
...