Commit edaf2aab authored by Neel Kant's avatar Neel Kant
Browse files

Add indexer_async

parent 0f0f60aa
import os
import time
import torch
import torch.distributed as dist
from megatron import get_args
from megatron.global_vars import set_global_variables
from megatron.initialize import init_distributed, _init_autoresume, _set_random_seed, _write_args_to_tensorboard
from megatron.mpu.initialize import set_data_parallel_group, set_model_parallel_group
# Example: 4x8 for training, 1x8 for indexing.
# Assign args.rank < 32 to TRAIN_PROCESS_GROUP, args.rank >= 32 to INDEX_PROCESS_GROUP
# can manually assign _MODEL_PARALLEL_GROUP to args.rank, _DATA_PARALLEL_GROUP to train or index process group
# for both, create a torchDDP accordingly because you need to set up the model to be data-parallel on each.
# Handshake flag shared between the indexer and trainer processes. It starts
# as None and is replaced by a one-element CUDA tensor (torch.zeros(1).cuda())
# inside initialize_and_run_async_megatron(); the flag is passed to
# dist.broadcast() by both AsyncIndexBuilder and AsyncREALMTrainer.
INDEX_READY = None
# Process group built in setup_groups() from ranks range(max_training_rank).
TRAIN_GROUP = None
# Process group built in setup_groups() from ranks
# range(max_training_rank, world_size).
INDEX_GROUP = None
# flow:
# index builder finishes first and sets INDEX_READY = 1.
# communicates by dist.broadcast(INDEX_READY, src=min_index_rank)
# index builder is now waiting for INDEX_READY = 0.
#
# at every iteration, trainer checks INDEX_READY = 1.
# when INDEX_READY = 1, reload the index, save model checkpoint and set INDEX_READY = 0.
# once done, trainer does dist.broadcast(INDEX_READY, src=min_train_rank)
# when INDEX_READY = 0, indexer loads up model checkpoint and begins again.
def pprint(*args):
    """Print ``args`` exactly like the built-in ``print`` and flush right away.

    Flushing on every call keeps log lines from different processes from
    sitting in a buffer while the process sleeps or blocks.
    """
    print(*args, flush=True)
def initialize_and_run_async_megatron(extra_args_provider=None, args_defaults=None,
                                      ignore_unknown_args=False, allow_no_cuda=False):
    """Initialize Megatron globals and process groups, then run this rank's role.

    Ranks below ``args.max_training_rank`` run ``AsyncREALMTrainer``; all
    other ranks run ``AsyncIndexBuilder``.

    Args:
        extra_args_provider: forwarded unchanged to ``set_global_variables``.
        args_defaults: dict of argument defaults forwarded to
            ``set_global_variables``. ``None`` (the default) means an empty
            dict; a fresh dict is created per call so that no state is shared
            between calls through a mutable default argument.
        ignore_unknown_args: forwarded unchanged to ``set_global_variables``.
        allow_no_cuda: when true, skip the CUDA-availability assertion.
            NOTE(review): ``INDEX_READY`` is still created with ``.cuda()``
            further down, so running without CUDA looks unsupported here.
    """
    global INDEX_READY

    if args_defaults is None:
        args_defaults = {}

    if not allow_no_cuda:
        # Make sure cuda is available.
        assert torch.cuda.is_available(), 'Megatron requires CUDA.'

    # Parse args, build tokenizer, and set adlr-autoresume,
    # tensorboard-writer, and timers.
    set_global_variables(extra_args_provider=extra_args_provider,
                         args_defaults=args_defaults,
                         ignore_unknown_args=ignore_unknown_args)

    # instead of _initialize_distributed()
    init_distributed()
    setup_groups()
    pprint('finished setting up groups')

    # Autoresume
    _init_autoresume()
    pprint('finished setting up autoresume')

    # Random seeds for reproducibility.
    args = get_args()
    if args.rank == 0:
        pprint('> setting random seeds to {} ...'.format(args.seed))
    # NOTE(review): seeding is currently disabled even though the message
    # above is still printed -- confirm whether this is intentional.
    # _set_random_seed(args.seed)

    # Write arguments to tensorboard.
    _write_args_to_tensorboard()
    pprint('finished writing args to tensorboard')

    # Wait for every rank (trainers and indexers) before splitting by role.
    torch.distributed.barrier()

    # The handshake flag starts at 0 on every rank.
    INDEX_READY = torch.zeros(1).cuda()

    if args.rank < args.max_training_rank:
        runner = AsyncREALMTrainer(args.rank)
        torch.distributed.barrier(TRAIN_GROUP)
        pprint("All trainers ready.")
        runner.dummy_train_model()
    else:
        runner = AsyncIndexBuilder(args.rank)
        torch.distributed.barrier(INDEX_GROUP)
        pprint("All indexers ready.")
        runner.dummy_build_index()
def setup_groups():
    """Create the trainer and indexer process groups and register them.

    Ranks ``[0, max_training_rank)`` form ``TRAIN_GROUP`` and ranks
    ``[max_training_rank, world_size)`` form ``INDEX_GROUP``. Each rank then
    registers the group it belongs to as its data-parallel group.
    """
    global TRAIN_GROUP
    global INDEX_GROUP

    args = get_args()
    world_size = dist.get_world_size()
    max_training_rank = args.max_training_rank

    # assuming no model parallelism right now
    set_model_parallel_group(args.rank)

    # important for batching and whatnot
    # Both new_group calls run on every rank, including ranks that are not
    # members of the group being created.
    TRAIN_GROUP = dist.new_group(list(range(max_training_rank)))
    INDEX_GROUP = dist.new_group(list(range(max_training_rank, world_size)))

    # Rank == max_training_rank is the first member of INDEX_GROUP, so the
    # comparison must be inclusive; a strict ">" would hand that rank the
    # TRAIN_GROUP it is not part of.
    if args.rank >= max_training_rank:
        set_data_parallel_group(INDEX_GROUP)
    else:
        set_data_parallel_group(TRAIN_GROUP)
class AsyncIndexBuilder(object):
    """Placeholder index builder that exercises the INDEX_READY handshake.

    No real index is built: each round sleeps to simulate the work, then
    toggles and broadcasts the module-level ``INDEX_READY`` flag.
    """

    def __init__(self, rank):
        """Store this process's global rank and log it."""
        self.rank = rank
        pprint("My rank: ", self.rank)

    def dummy_build_index(self):
        """Run five simulated build rounds, signalling the trainers after each.

        NOTE(review): the handles returned by the ``async_op=True``
        broadcasts are never waited on or checked, and only the rank equal
        to ``args.max_training_rank`` issues the first broadcast -- confirm
        the intended protocol before relying on this.
        """
        global INDEX_READY

        began_at = time.time()
        pprint("START: {}".format(time.ctime(began_at)))
        pprint("-" * 100)

        for _ in range(5):
            # simulating building the index which takes 20 seconds
            time.sleep(20)
            finished_at = time.ctime(time.time())
            pprint('built the index. Time: {}'.format(finished_at))

            args = get_args()
            if self.rank == args.max_training_rank:
                # broadcasting that the index is ready
                # (1 - flag builds a new tensor and rebinds the module global)
                INDEX_READY = 1 - INDEX_READY
                send_handle = dist.broadcast(INDEX_READY, args.max_training_rank,
                                             async_op=True)
                pprint("Broadcasted index ready = ", INDEX_READY)

            torch.distributed.barrier(INDEX_GROUP)

            # Post a receive for the flag broadcast from rank 0, then poll the
            # local flag for as long as it still reads 1.
            recv_handle = dist.broadcast(INDEX_READY, 0, async_op=True)
            while INDEX_READY == 1:
                polled_at = time.ctime(time.time())
                pprint('waiting for new model. Time: {}'.format(polled_at))
                time.sleep(1)
class AsyncREALMTrainer(object):
    """Placeholder trainer that exercises the INDEX_READY handshake.

    No model is trained: each round waits until the module-level
    ``INDEX_READY`` flag reads 1, after which rank 0 toggles the flag and
    broadcasts it.
    """

    def __init__(self, rank):
        """Store this process's global rank and log it."""
        self.rank = rank
        pprint("My rank: ", self.rank)

    def dummy_train_model(self):
        """Run five rounds of waiting for the index and acknowledging it.

        NOTE(review): the handles returned by the ``async_op=True``
        broadcasts are never waited on or checked, and only rank 0 issues
        the acknowledging broadcast -- confirm the intended protocol before
        relying on this.
        """
        global INDEX_READY

        began_at = time.time()
        pprint("START: {}".format(time.ctime(began_at)))
        pprint("-" * 100)

        args = get_args()
        for _ in range(5):
            # Post a receive for the flag broadcast from the rank equal to
            # args.max_training_rank, then poll the local flag until it is 1.
            recv_handle = dist.broadcast(INDEX_READY, args.max_training_rank,
                                         async_op=True)
            while not (INDEX_READY == 1):
                assert self.rank != args.max_training_rank
                polled_at = time.ctime(time.time())
                pprint('waiting for new index. Time: {}'.format(polled_at))
                time.sleep(2)

            # INDEX_READY is 1
            if self.rank == 0:
                # (1 - flag builds a new tensor and rebinds the module global)
                INDEX_READY = 1 - INDEX_READY
                send_handle = dist.broadcast(INDEX_READY, self.rank, async_op=True)
                pprint("Broadcasted index ready = ", INDEX_READY)

            torch.distributed.barrier(TRAIN_GROUP)
if __name__ == "__main__":
    # Script entry point: run the async trainer/indexer flow, supplying a
    # default for the tokenizer_type argument.
    default_args = {'tokenizer_type': 'BertWordPieceLowerCase'}
    initialize_and_run_async_megatron(args_defaults=default_args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment