Commit e338e311 authored by Neel Kant

Indexer_async works in theory

parent edaf2aab
......@@ -20,23 +20,6 @@ from pretrain_bert_ict import get_batch, model_provider
from indexer_utils import set_index_com_file_ready, set_model_com_file_not_ready, check_model_com_file_ready
# TODO re: main()
# consider broadcasting/all-reducing everything in memory rather than going through the filesystem
# create a different process group in the same NCCL world - no need to use checkpoints on disk or transfer things on disk
# torch.distributed.new_group takes a list of ranks and returns a group that can be handed to the collective operations
# create a training process group and an indexing process group
# pass the training group to DDP instead of the whole world process group
# use the indexing process group for the shard-combining
# a communication group between process "8" and process "0" tells the training group that there is a new index
# also, process 0 sends process 8 the new model
# if launching a separate process for indexing, may have to work with environment variables to
# allocate the resources well, and subsequently assign the correct GPUs to the indexing job
# consider initializing everything in a single group and breaking off processes based on their ranks
# for debugging purposes, have the training process group check every so many intervals,
# and if the index isn't ready, wait so that the state stays consistent. Start with using the filesystem.
# (a minimal sketch of the group split is included below)
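# A minimal sketch of the group split described in the TODO above, assuming the
# default process group is already initialized and that ranks [0, max_training_rank)
# train while the remaining ranks build the index. The helper name and structure
# are illustrative only, not part of this commit.
import torch.distributed as dist

def make_process_groups(max_training_rank):
    world_size = dist.get_world_size()
    # new_group is a collective call: every rank must create both groups,
    # even the group it does not belong to
    train_group = dist.new_group(ranks=list(range(max_training_rank)))
    index_group = dist.new_group(ranks=list(range(max_training_rank, world_size)))
    return train_group, index_group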
def test_retriever():
# TODO: Update this because it's outdated and definitely won't run.
initialize_megatron(extra_args_provider=None,
......@@ -66,9 +49,11 @@ def main():
initialize_megatron(extra_args_provider=None,
args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
args = get_args()
ran_once = False
while True:
model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=ran_once)
try:
model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=True)
except Exception:  # fall back to the plain ICT checkpoint when no REALM checkpoint exists yet
model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=False)
model.eval()
dataset = get_ict_dataset()
data_iter = iter(get_one_epoch_dataloader(dataset))
......@@ -93,7 +78,7 @@ def main():
total += block_indices.size
i += 1
if i % 20 == 0:
if i % 2000 == 0:
print('Batch {:10d} | Total {:10d}'.format(i, total), flush=True)
if args.debug:
break
......@@ -107,7 +92,6 @@ def main():
else:
all_block_data.clear()
ran_once = True
set_index_com_file_ready()
torch.distributed.barrier()
if args.async_indexer:
......
......@@ -111,7 +111,7 @@ class AsyncIndexBuilder(object):
pprint("-" * 100)
for i in range(5):
# simulate building the index; the sleep below stands in for the real build time
time.sleep(20)
time.sleep(10)
pprint('built the index. Time: {}'.format(time.ctime(time.time())))
args = get_args()
......@@ -121,8 +121,11 @@ class AsyncIndexBuilder(object):
INDEX_READY = 1 - INDEX_READY
send_handle = dist.broadcast(INDEX_READY, args.max_training_rank, async_op=True)
pprint("Broadcasted index ready = ", INDEX_READY)
else:
send_recv_handle = dist.broadcast(INDEX_READY, args.max_training_rank, async_op=True)
torch.distributed.barrier(INDEX_GROUP)
pprint("Synced after broadcasting")
recv_handle = dist.broadcast(INDEX_READY, 0, async_op=True)
while INDEX_READY == 1:
......@@ -154,12 +157,14 @@ class AsyncREALMTrainer(object):
# INDEX_READY is 1
if self.rank == 0:
INDEX_READY = 1 - INDEX_READY
send_handle = dist.broadcast(INDEX_READY, self.rank, async_op=True)
send_handle = dist.broadcast(INDEX_READY, 0, async_op=True)
pprint("Broadcasted index ready = ", INDEX_READY)
else:
send_recv_handle = dist.broadcast(INDEX_READY, 0, async_op=True)
torch.distributed.barrier(TRAIN_GROUP)
pprint("Synced after broadcasting")
if __name__ == "__main__":
initialize_and_run_async_megatron(args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
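# A hedged sketch of the INDEX_READY handshake used by AsyncIndexBuilder and
# AsyncREALMTrainer above, under the assumption that INDEX_READY is a one-element
# tensor shared by every rank: one side flips the flag and broadcasts it from its
# source rank, then synchronizes its own sub-group. Only INDEX_READY and
# max_training_rank come from the diff; the helper below is illustrative.
import torch.distributed as dist

def signal_flag(index_ready, src_rank, local_group):
    # flip the flag in place and announce the new value to all ranks
    index_ready.copy_(1 - index_ready)
    handle = dist.broadcast(index_ready, src=src_rank, async_op=True)
    # keep the local (training or indexing) sub-group in step before moving on
    dist.barrier(group=local_group)
    return handle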
# indexer_utils.py (new file)
import os

INDEX_COM_FILE = 'ready.index'
MODEL_COM_FILE = 'ready.model'

def set_index_com_file_not_ready():
    with open(INDEX_COM_FILE, 'w') as com_file:
        com_file.write('0')

def set_index_com_file_ready():
    with open(INDEX_COM_FILE, 'w') as com_file:
        com_file.write('1')

def check_index_com_file_ready():
    if not os.path.exists(INDEX_COM_FILE):
        set_index_com_file_not_ready()
    with open(INDEX_COM_FILE, 'r') as com_file:
        # the file stores '0' or '1'; compare explicitly rather than relying on string truthiness
        return com_file.readline().strip() == '1'

def set_model_com_file_not_ready():
    with open(MODEL_COM_FILE, 'w') as com_file:
        com_file.write('0')

def set_model_com_file_ready():
    with open(MODEL_COM_FILE, 'w') as com_file:
        com_file.write('1')

def check_model_com_file_ready():
    if not os.path.exists(MODEL_COM_FILE):
        # initialize the model flag file if it does not exist yet
        set_model_com_file_not_ready()
    with open(MODEL_COM_FILE, 'r') as com_file:
        return com_file.readline().strip() == '1'
......@@ -195,6 +195,7 @@ def _add_training_args(parser):
'by this value.')
group.add_argument('--tensorboard-dir', type=str, default=None,
help='Write TensorBoard logs to this directory.')
group.add_argument('--max-training-rank', type=int, default=None,
help='Ranks below this value run training; ranks at or above it run the async index builder.')
return parser
......
......@@ -61,8 +61,7 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
_write_args_to_tensorboard()
def _initialize_distributed():
"""Initialize torch.distributed and mpu."""
def init_distributed():
args = get_args()
device_count = torch.cuda.device_count()
......@@ -102,6 +101,13 @@ def _initialize_distributed():
world_size=args.world_size, rank=args.rank,
init_method=init_method)
def _initialize_distributed():
"""Initialize torch.distributed and mpu."""
init_distributed()
args = get_args()
device_count = torch.cuda.device_count()
# Set the model-parallel / data-parallel communicators.
if device_count > 0:
mpu.initialize_model_parallel(args.model_parallel_size)
......
......@@ -96,6 +96,13 @@ def get_model_parallel_group():
return _MODEL_PARALLEL_GROUP
def set_model_parallel_group(group):
global _MODEL_PARALLEL_GROUP
assert _MODEL_PARALLEL_GROUP is None, \
'model parallel group has already been initialized'
_MODEL_PARALLEL_GROUP = group
def get_data_parallel_group():
"""Get the data parallel group the caller rank belongs to."""
assert _DATA_PARALLEL_GROUP is not None, \
......@@ -103,6 +110,13 @@ def get_data_parallel_group():
return _DATA_PARALLEL_GROUP
def set_data_parallel_group(group):
global _DATA_PARALLEL_GROUP
assert _DATA_PARALLEL_GROUP is None, \
'data parallel group has already been initialized'
_DATA_PARALLEL_GROUP = group
def set_model_parallel_world_size(world_size):
"""Set the model parallel size"""
global _MPU_WORLD_SIZE
......
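# A hedged sketch of how the pieces in this commit could fit together: the async
# launcher calls the bare init_distributed(), carves out a training-only group,
# and installs it through the new mpu setters so collectives never involve the
# indexer ranks. The layout below (data-parallel across all training ranks,
# model-parallel size 1) is an assumption for illustration, and it presumes the
# setters are re-exported by megatron.mpu like the other initialize helpers.
import torch.distributed as dist
from megatron import mpu

def install_training_groups(max_training_rank):
    rank = dist.get_rank()
    # every rank must take part in creating every group, even ones it is not in
    train_group = dist.new_group(ranks=list(range(max_training_rank)))
    own_group = None
    for r in range(dist.get_world_size()):
        single = dist.new_group(ranks=[r])
        if r == rank:
            own_group = single
    if rank < max_training_rank:
        mpu.set_data_parallel_group(train_group)
        # with a model-parallel size of 1, each rank forms its own model-parallel group
        mpu.set_model_parallel_group(own_group)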