Commit d4b00be0 authored by Neel Kant

Reorganize indexer. Things run up to saving model checkpoint and repeating

parent e338e311
@@ -3,6 +3,7 @@ import sys
import time
import torch
+import torch.distributed as dist
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

from megatron import get_args, get_adlr_autoresume, print_rank_0
@@ -14,58 +15,128 @@ from megatron.data.realm_index import detach, BlockData, FaissMIPSIndex
from megatron.data.samplers import DistributedBatchSampler
from megatron.initialize import initialize_megatron
from megatron.model import REALMRetriever
+from megatron.global_vars import set_global_variables
+from megatron.mpu.initialize import get_index_ready, get_index_group, get_train_group
+from megatron.mpu.initialize import set_data_parallel_group, set_model_parallel_group, init_realm_groups
+from megatron.initialize import init_distributed, _init_autoresume, _set_random_seed, _write_args_to_tensorboard
from megatron.training import get_model
from megatron.utils import check_adlr_autoresume_termination
from pretrain_bert_ict import get_batch, model_provider
-from indexer_utils import set_index_com_file_ready, set_model_com_file_not_ready, check_model_com_file_ready

-def test_retriever():
-    # TODO: Update this because it's outdated and definitely won't run.
-    initialize_megatron(extra_args_provider=None,
-                        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
+INDEX_READY = None


+def pprint(*args):
+    print(*args, flush=True)


+def initialize_and_run_async_megatron(extra_args_provider=None, args_defaults={},
+                                      ignore_unknown_args=False, allow_no_cuda=False):
+    if not allow_no_cuda:
+        # Make sure cuda is available.
+        assert torch.cuda.is_available(), 'Megatron requires CUDA.'
+    # Parse args, build tokenizer, and set adlr-autoresume,
+    # tensorboard-writer, and timers.
+    set_global_variables(extra_args_provider=extra_args_provider,
+                         args_defaults=args_defaults,
+                         ignore_unknown_args=ignore_unknown_args)
+    # instead of _initialize_distributed()
+    init_distributed()
+    setup_realm_groups_and_vars()
+    global INDEX_READY
+    INDEX_READY = get_index_ready()
+    pprint('finished setting up groups')
+    # Autoresume
+    _init_autoresume()
+    pprint('finished setting up autoresume')
+    # Random seeds for reproducibility.
    args = get_args()
-    model = load_ict_checkpoint()
-    model.eval()
-    dataset = get_ict_dataset()
-    block_data = BlockData.load_from_file(args.block_data_path)
-    mips_index = FaissMIPSIndex('flat_ip', 128)
-    mips_index.add_block_embed_data(block_data)
-    retriever = REALMRetriever(model, dataset, block_data, mips_index, top_k=5)
-    strs = [
-        "The last monarch from the house of windsor",
-        "married to Elvis Presley",
-        "tallest building in the world today",
-        "who makes graphics cards"
-    ]
-    for s in strs:
-        retriever.retrieve_evidence_blocks_text(s)
+    if args.rank == 0:
+        pprint('> setting random seeds to {} ...'.format(args.seed))
+    _set_random_seed(args.seed)
+    # Write arguments to tensorboard.
+    _write_args_to_tensorboard()
+    pprint('finished writing args to tensorboard')

+    torch.distributed.barrier()

+    if args.rank < args.max_training_rank:
+        torch.distributed.barrier(get_train_group())
+        pprint("All trainers ready.")
+        return
+    else:
+        runner = AsyncIndexBuilder(args.rank)
+        torch.distributed.barrier(get_index_group())
+        pprint("All indexers ready.")
+        runner.run_async()


-def main():
-    initialize_megatron(extra_args_provider=None,
-                        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
+def setup_realm_groups_and_vars():
    args = get_args()
-    while True:
+    world_size = dist.get_world_size()
+    max_training_rank = args.max_training_rank
+    # assuming no model parallelism right now
+    set_model_parallel_group(dist.new_group([args.rank]))
+    init_realm_groups(max_training_rank, world_size)
+    if args.rank < max_training_rank:
+        set_data_parallel_group(get_train_group())
+    else:
+        set_data_parallel_group(get_index_group())


+class AsyncIndexBuilder(object):
+    def __init__(self, rank):
+        self.rank = rank
+        args = get_args()
+        self.is_main_builder = self.rank == args.max_training_rank
+        self.main_builder_idx = args.max_training_rank
+        self.debug = args.debug

+        self.model = None
+        self.dataloader = None
+        self.block_data = None
+        self.load_attributes()

+        global INDEX_READY
+        INDEX_READY = get_index_ready()

+    def run_async(self):
+        while True:
+            print("Starting (again!)")
+            self.build_index()
+            self.save_index()
+            self.send_index_ready_signal()
+            while INDEX_READY == 1:
+                print("Waiting for new model checkpoint.")
+                time.sleep(1)
+            self.load_model()

+    def load_attributes(self):
        try:
-            model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=True)
+            self.model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=True)
        except:
-            model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=False)
-        model.eval()
-        dataset = get_ict_dataset()
-        data_iter = iter(get_one_epoch_dataloader(dataset))
-        all_block_data = BlockData()
+            self.model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=False)
+        self.model.eval()
+        self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset()))
+        self.block_data = BlockData()

+    def build_index(self):
        i = 1
        total = 0
        while True:
            with torch.no_grad():
                try:
                    query_tokens, query_pad_mask, \
-                        block_tokens, block_pad_mask, block_index_data = get_batch(data_iter)
+                        block_tokens, block_pad_mask, block_index_data = get_batch(self.dataloader)
                except:
                    break
@@ -73,30 +144,16 @@ def main():
                block_indices = block_index_data[:, 3]
                block_meta = block_index_data[:, :3]
-                block_logits = detach(model(None, None, block_tokens, block_pad_mask, only_block=True))
-                all_block_data.add_block_data(block_indices, block_logits, block_meta)
+                block_logits = detach(self.model(None, None, block_tokens, block_pad_mask, only_block=True))
+                self.block_data.add_block_data(block_indices, block_logits, block_meta)
                total += block_indices.size
                i += 1
                if i % 2000 == 0:
                    print('Batch {:10d} | Total {:10d}'.format(i, total), flush=True)
-                if args.debug:
+                if self.debug:
                    break

-        all_block_data.save_shard(args.rank)
-        torch.distributed.barrier()
-        del model
-
-        if args.rank == 0:
-            all_block_data.consolidate_shards_and_save()
-        else:
-            all_block_data.clear()
-
-        set_index_com_file_ready()
-        torch.distributed.barrier()
-
-        if args.async_indexer:
-            while not check_model_com_file_ready():
-                time.sleep(5)
        autoresume = get_adlr_autoresume()
        if autoresume.termination_requested():
            print_rank_0(">>> autoresume termination request found!")
@@ -105,17 +162,36 @@ def main():
            print_rank_0(">>> training terminated. Returning")
            sys.exit(0)

-        set_model_com_file_not_ready()
+    def save_index(self):
+        self.block_data.save_shard(self.rank)
+        torch.distributed.barrier()
+        del self.model

+        if self.is_main_builder:
+            self.block_data.consolidate_shards_and_save(ignore_shard=self.rank)
+        else:
+            self.block_data.clear()

+    def send_index_ready_signal(self):
+        global INDEX_READY
+        if self.is_main_builder:
+            INDEX_READY = 1 - INDEX_READY
+            print("Switched INDEX_READY", flush=True)
+        send_handle = dist.broadcast(INDEX_READY, self.main_builder_idx, async_op=True)
+        torch.distributed.barrier(get_index_group())
+        recv_handle = dist.broadcast(INDEX_READY, 0, async_op=True)


def load_ict_checkpoint(only_query_model=False, only_block_model=False, no_grad=False, from_realm_chkpt=False):
    args = get_args()
    model = get_model(lambda: model_provider(only_query_model, only_block_model))

-    load_path = args.load if from_realm_chkpt else args.ict_load
    if isinstance(model, torchDDP):
        model = model.module

+    load_path = args.load if from_realm_chkpt else args.ict_load
    tracker_filename = get_checkpoint_tracker_filename(load_path)
    with open(tracker_filename, 'r') as f:
        iteration = int(f.read().strip())
@@ -174,7 +250,9 @@ def get_one_epoch_dataloader(dataset):
    args = get_args()

    world_size = mpu.get_data_parallel_world_size()
+    print(world_size, flush=True)
    rank = mpu.get_data_parallel_rank()
+    print(rank, flush=True)
    global_batch_size = args.batch_size * world_size
    num_workers = args.num_workers
...
import os
import time

import torch
import torch.distributed as dist

from megatron import get_args
from megatron.global_vars import set_global_variables
from megatron.initialize import init_distributed, _init_autoresume, _set_random_seed, _write_args_to_tensorboard
from megatron.mpu.initialize import set_data_parallel_group, set_model_parallel_group

# Example: 4x8 for training, 1x8 for indexing.
# Assign args.rank < 32 to TRAIN_PROCESS_GROUP, args.rank >= to INDEX_PROCESS_GROUP
# can manually assign _MODEL_PARALLEL_GROUP to args.rank, _DATA_PARALLEL_GROUP to train or index process group
# for both, create a torchDDP accordingly because you need to set up the model to be data-parallel on each.

INDEX_READY = None
TRAIN_GROUP = None
INDEX_GROUP = None

# flow:
# index builder finishes first and sets INDEX_READY = 1.
# communicates by dist.broadcast(INDEX_READY, src=min_index_rank)
# index builder is now waiting for INDEX_READY = 0.
#
# at every iteration, trainer checks INDEX_READY = 1.
# when INDEX_READY = 1, reload the index, save model checkpoint and set INDEX_READY = 0.
# once done, trainer does dist.broadcast(INDEX_READY, src=min_train_rank)
# when INDEX_READY = 0, indexer loads up model checkpoint and begins again.


def pprint(*args):
    print(*args, flush=True)


def initialize_and_run_async_megatron(extra_args_provider=None, args_defaults={},
                                      ignore_unknown_args=False, allow_no_cuda=False):
    if not allow_no_cuda:
        # Make sure cuda is available.
        assert torch.cuda.is_available(), 'Megatron requires CUDA.'

    # Parse args, build tokenizer, and set adlr-autoresume,
    # tensorboard-writer, and timers.
    set_global_variables(extra_args_provider=extra_args_provider,
                         args_defaults=args_defaults,
                         ignore_unknown_args=ignore_unknown_args)

    # instead of _initialize_distributed()
    init_distributed()
    setup_groups()
    pprint('finished setting up groups')

    # Autoresume
    _init_autoresume()
    pprint('finished setting up autoresume')

    # Random seeds for reproducibility.
    args = get_args()
    if args.rank == 0:
        pprint('> setting random seeds to {} ...'.format(args.seed))
    # _set_random_seed(args.seed)

    # Write arguments to tensorboard.
    _write_args_to_tensorboard()
    pprint('finished writing args to tensorboard')

    torch.distributed.barrier()

    global INDEX_READY
    INDEX_READY = torch.zeros(1).cuda()

    if args.rank < args.max_training_rank:
        runner = AsyncREALMTrainer(args.rank)
        torch.distributed.barrier(TRAIN_GROUP)
        pprint("All trainers ready.")
        runner.dummy_train_model()
    else:
        runner = AsyncIndexBuilder(args.rank)
        torch.distributed.barrier(INDEX_GROUP)
        pprint("All indexers ready.")
        runner.dummy_build_index()


def setup_groups():
    args = get_args()
    world_size = dist.get_world_size()
    max_training_rank = args.max_training_rank

    # assuming no model parallelism right now
    set_model_parallel_group(args.rank)

    global TRAIN_GROUP
    global INDEX_GROUP
    # important for batching and whatnot
    TRAIN_GROUP = dist.new_group(list(range(max_training_rank)))
    INDEX_GROUP = dist.new_group(list(range(max_training_rank, world_size)))

    if args.rank > max_training_rank:
        set_data_parallel_group(INDEX_GROUP)
    else:
        set_data_parallel_group(TRAIN_GROUP)


class AsyncIndexBuilder(object):
    def __init__(self, rank):
        self.rank = rank
        pprint("My rank: ", self.rank)

    def dummy_build_index(self):
        start_time = time.time()
        pprint("START: {}".format(time.ctime(start_time)))
        pprint("-" * 100)

        for i in range(5):
            # simulating building the index which takes 20 seconds
            time.sleep(10)
            pprint('built the index. Time: {}'.format(time.ctime(time.time())))

            args = get_args()
            global INDEX_READY
            if self.rank == args.max_training_rank:
                # broadcasting that the index is ready
                INDEX_READY = 1 - INDEX_READY
                send_handle = dist.broadcast(INDEX_READY, args.max_training_rank, async_op=True)
                pprint("Broadcasted index ready = ", INDEX_READY)
            else:
                send_recv_handle = dist.broadcast(INDEX_READY, args.max_training_rank, async_op=True)
            torch.distributed.barrier(INDEX_GROUP)
            pprint("Synced after broadcasting")

            recv_handle = dist.broadcast(INDEX_READY, 0, async_op=True)
            while INDEX_READY == 1:
                pprint('waiting for new model. Time: {}'.format(time.ctime(time.time())))
                time.sleep(1)


class AsyncREALMTrainer(object):
    def __init__(self, rank):
        self.rank = rank
        pprint("My rank: ", self.rank)

    def dummy_train_model(self):
        start_time = time.time()
        pprint("START: {}".format(time.ctime(start_time)))
        pprint("-" * 100)

        args = get_args()
        for i in range(5):
            global INDEX_READY
            recv_handle = dist.broadcast(INDEX_READY, args.max_training_rank, async_op=True)
            while True:
                if INDEX_READY == 1:
                    break
                assert self.rank != args.max_training_rank
                pprint('waiting for new index. Time: {}'.format(time.ctime(time.time())))
                time.sleep(2)

            # INDEX_READY is 1
            if self.rank == 0:
                INDEX_READY = 1 - INDEX_READY
                send_handle = dist.broadcast(INDEX_READY, 0, async_op=True)
                pprint("Broadcasted index ready = ", INDEX_READY)
            else:
                send_recv_handle = dist.broadcast(INDEX_READY, 0, async_op=True)
            torch.distributed.barrier(TRAIN_GROUP)
            pprint("Synced after broadcasting")


if __name__ == "__main__":
    initialize_and_run_async_megatron(args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
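
To make the INDEX_READY handshake described in the "flow" comment of the file above concrete, here is a minimal, self-contained sketch of the same signalling pattern, reduced to three ranks on the gloo backend. It is illustrative only and not part of this commit: the world size, the MAX_TRAINING_RANK value, the file name, the port, and the round count are assumptions chosen for the example, while the real code keeps the flag in the mpu globals and overlaps the exchange with actual training and index building.

# handshake_sketch.py -- illustrative only, not part of this commit.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

MAX_TRAINING_RANK = 2   # ranks 0..1 act as trainers, rank 2 as the index builder (assumption)
WORLD_SIZE = 3
ROUNDS = 2


def run(rank):
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('gloo', rank=rank, world_size=WORLD_SIZE)

    # Disjoint groups, mirroring init_realm_groups(max_training_rank, world_size).
    train_group = dist.new_group(list(range(MAX_TRAINING_RANK)))
    index_group = dist.new_group(list(range(MAX_TRAINING_RANK, WORLD_SIZE)))

    index_ready = torch.zeros(1)  # shared flag: 0 = stale index, 1 = fresh index

    for _ in range(ROUNDS):
        if rank == MAX_TRAINING_RANK:
            # Index builder side: "build" the index, then flip the flag to 1.
            index_ready[0] = 1 - index_ready[0]
        # Every rank enters the broadcast from the main index builder.
        dist.broadcast(index_ready, src=MAX_TRAINING_RANK)

        if rank < MAX_TRAINING_RANK:
            # Trainer side: reloading the index and saving a checkpoint would
            # happen here; then rank 0 flips the flag back to release the indexer.
            dist.barrier(train_group)
            if rank == 0:
                index_ready[0] = 1 - index_ready[0]
        else:
            dist.barrier(index_group)
        # Every rank enters the broadcast from trainer rank 0.
        dist.broadcast(index_ready, src=0)
        print('rank {}: round done, INDEX_READY={}'.format(rank, int(index_ready.item())), flush=True)

    dist.destroy_process_group()


if __name__ == '__main__':
    mp.spawn(run, nprocs=WORLD_SIZE)

Note that a torch.distributed collective has to be entered by every rank of the group it runs on, which is why the sketch issues both broadcasts unconditionally on all ranks; the commit issues the corresponding broadcasts with async_op=True and synchronizes the groups with barriers.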
@@ -187,8 +187,8 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
        # parallel case
        counts = torch.cuda.LongTensor([1])
        torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
-        assert counts[0].item() == torch.distributed.get_world_size(
-            group=mpu.get_data_parallel_group())
+        #assert counts[0].item() == torch.distributed.get_world_size(
+        #    group=mpu.get_data_parallel_group())

    # Load indexed dataset.
    print_rank_0(' > loading indexed mapping from {}'.format(
...
@@ -26,6 +26,10 @@ _MODEL_PARALLEL_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None

+_TRAIN_GROUP = None
+_INDEX_GROUP = None
+_INDEX_READY = None

# These values enable us to change the mpu sizes on the fly.
_MPU_WORLD_SIZE = None
_MPU_RANK = None
@@ -105,8 +109,10 @@ def set_model_parallel_group(group):

def get_data_parallel_group():
    """Get the data parallel group the caller rank belongs to."""
+    #print(">>> yeah this function works.")
    assert _DATA_PARALLEL_GROUP is not None, \
        'data parallel group is not initialized'
+    #print(_DATA_PARALLEL_GROUP)
    return _DATA_PARALLEL_GROUP
@@ -114,6 +120,7 @@ def set_data_parallel_group(group):
    global _DATA_PARALLEL_GROUP
    assert _DATA_PARALLEL_GROUP is None, \
        'data parallel group has already been initialized'
+    print(">>> setting data parallel group: ", group, flush=True)
    _DATA_PARALLEL_GROUP = group
@@ -169,3 +176,30 @@ def destroy_model_parallel():
    _MODEL_PARALLEL_GROUP = None
    global _DATA_PARALLEL_GROUP
    _DATA_PARALLEL_GROUP = None


+def init_realm_groups(max_training_rank, world_size):
+    global _TRAIN_GROUP
+    _TRAIN_GROUP = torch.distributed.new_group(list(range(max_training_rank)))
+    global _INDEX_GROUP
+    _INDEX_GROUP = torch.distributed.new_group(list(range(max_training_rank, world_size)))
+    global _INDEX_READY
+    _INDEX_READY = torch.zeros(1).cuda()


+def get_train_group():
+    global _TRAIN_GROUP
+    assert _TRAIN_GROUP is not None
+    return _TRAIN_GROUP


+def get_index_group():
+    global _INDEX_GROUP
+    assert _INDEX_GROUP is not None
+    return _INDEX_GROUP


+def get_index_ready():
+    global _INDEX_READY
+    assert _INDEX_READY is not None
+    return _INDEX_READY
@@ -36,14 +36,18 @@ from megatron.initialize import initialize_megatron
from megatron.learning_rates import AnnealingLR
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import get_params_for_weight_decay_optimization
+from megatron.mpu.initialize import get_index_ready, get_train_group
from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import make_data_loader
from megatron.utils import report_memory
-from indexer_utils import check_index_com_file_ready, set_index_com_file_not_ready, set_model_com_file_ready

+INDEX_READY = None


def pretrain(train_valid_test_dataset_provider, model_provider,
-             forward_step_func, extra_args_provider=None, args_defaults={}):
+             forward_step_func, extra_args_provider=None, args_defaults={},
+             initializer_func=None):
    """Main training program.

    This function will run the followings in the order provided:
@@ -69,8 +73,15 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
    """

    # Initalize and get arguments, timers, and Tensorboard writer.
-    initialize_megatron(extra_args_provider=extra_args_provider,
-                        args_defaults=args_defaults)
+    if initializer_func is None:
+        initialize_megatron(extra_args_provider=extra_args_provider,
+                            args_defaults=args_defaults)
+    else:
+        initializer_func(extra_args_provider=extra_args_provider,
+                         args_defaults=args_defaults)

+    global INDEX_READY
+    INDEX_READY = get_index_ready()

    args = get_args()
    timers = get_timers()
@@ -250,7 +261,6 @@ def backward_step(optimizer, model, loss):
        else:
            optimizer.clip_master_grads(args.clip_grad)

-ran_backward_once = False

def train_step(forward_step_func, data_iterator,
               model, optimizer, lr_scheduler):
@@ -363,15 +373,20 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
    timers('interval time').start()
    report_memory_flag = True

+    global INDEX_READY
+    recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, async_op=True)
    while iteration < args.train_iters:
-        if hasattr(model, 'retriever'):
-            new_index_ready = check_index_com_file_ready()
-            if new_index_ready:
-                model.retriever.reload_index()
-                set_index_com_file_not_ready()
-                save_checkpoint(iteration, model, optimizer, lr_scheduler)
-                set_model_com_file_ready()
+        if hasattr(model, 'retriever') and INDEX_READY == 1:
+            model.retriever.reload_index()
+            save_checkpoint(iteration, model, optimizer, lr_scheduler)
+            torch.distributed.barrier()
+            if args.rank == 0:
+                INDEX_READY = 1 - INDEX_READY
+                print("Switched index ready", flush=True)
+            send_handle = torch.distributed.broadcast(INDEX_READY, 0, async_op=True)
+            torch.distributed.barrier(get_train_group())
+            recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, async_op=True)

        loss_dict, skipped_iter = train_step(forward_step_func,
                                             train_data_iterator,
...
@@ -27,6 +27,7 @@ from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import REALMBertModel, REALMRetriever
from megatron.training import pretrain
from megatron.utils import reduce_losses
+from indexer import initialize_and_run_async_megatron

num_batches = 0

@@ -177,4 +178,5 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):

if __name__ == "__main__":
    pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
-             args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
+             args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'},
+             initializer_func=initialize_and_run_async_megatron)