Commit 5684f904 authored by Neel Kant

Full cycle of communication complete. Also added BasicIndexBuilder

parent d4b00be0
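This change set wires up a full round trip of the INDEX_READY signal between the training ranks and the asynchronous index-builder ranks: the builder builds and saves an index, toggles INDEX_READY and broadcasts it; the trainer sees the flag, reloads the retriever index, saves a checkpoint, toggles the flag back, and broadcasts again. The standalone sketch below is not part of the commit; it only illustrates the same toggle-and-broadcast handshake in isolation, using the gloo backend, blocking broadcasts instead of the async_op handles in the diff, a hardcoded local address, and rank 0/1 standing in for trainer/builder (all assumptions).

# Standalone sketch of the INDEX_READY handshake (not Megatron code).
# Rank 0 stands in for the trainer, rank 1 for the async index builder.
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank, world_size):
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500",
                            rank=rank, world_size=world_size)
    index_ready = torch.zeros(1, dtype=torch.long)

    for cycle in range(3):
        if rank == 1:
            # Builder: pretend the index was just built and saved, then flip the flag.
            index_ready.fill_(1 - int(index_ready.item()))
            print(f"[builder] cycle {cycle}: index saved, INDEX_READY -> {int(index_ready.item())}")
        dist.broadcast(index_ready, src=1)   # builder announces "index ready"

        if rank == 0:
            # Trainer: sees the flag, would reload the index and save a checkpoint here,
            # then flips the flag back to hand control to the builder.
            print(f"[trainer] cycle {cycle}: saw INDEX_READY={int(index_ready.item())}, "
                  "reloading index and saving checkpoint")
            index_ready.fill_(1 - int(index_ready.item()))
        dist.broadcast(index_ready, src=0)   # trainer announces "checkpoint ready"
        dist.barrier()

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2)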
@@ -110,15 +110,14 @@ class AsyncIndexBuilder(object):
     def run_async(self):
         while True:
-            print("Starting (again!)")
-            self.build_index()
-            self.save_index()
+            print("Starting (again!)", flush=True)
+            self.build_and_save_index()
             self.send_index_ready_signal()
             while INDEX_READY == 1:
-                print("Waiting for new model checkpoint.")
-                time.sleep(1)
-            self.load_model()
+                print("Waiting for new model checkpoint.", flush=True)
+                time.sleep(5)
+            self.load_attributes()

     def load_attributes(self):
         try:
@@ -129,7 +128,7 @@ class AsyncIndexBuilder(object):
         self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset()))
         self.block_data = BlockData()

-    def build_index(self):
+    def build_and_save_index(self):
         i = 1
         total = 0
         while True:
@@ -149,7 +148,7 @@ class AsyncIndexBuilder(object):
                 total += block_indices.size
                 i += 1
-                if i % 2000 == 0:
+                if i % 10 == 0:
                     print('Batch {:10d} | Total {:10d}'.format(i, total), flush=True)
                 if self.debug:
                     break
@@ -162,27 +161,68 @@ class AsyncIndexBuilder(object):
             print_rank_0(">>> training terminated. Returning")
             sys.exit(0)

-    def save_index(self):
         self.block_data.save_shard(self.rank)
         torch.distributed.barrier()
         del self.model

         if self.is_main_builder:
             self.block_data.consolidate_shards_and_save(ignore_shard=self.rank)
-        else:
-            self.block_data.clear()
+        self.block_data.clear()

     def send_index_ready_signal(self):
         global INDEX_READY
         if self.is_main_builder:
             INDEX_READY = 1 - INDEX_READY
             print("Switched INDEX_READY", flush=True)
+            import time
+            print(time.ctime(time.time()), flush=True)
         send_handle = dist.broadcast(INDEX_READY, self.main_builder_idx, async_op=True)
         torch.distributed.barrier(get_index_group())
         recv_handle = dist.broadcast(INDEX_READY, 0, async_op=True)

+
+class BasicIndexBuilder(object):
+    def __init__(self):
+        args = get_args()
+        self.rank = args.rank
+        self.model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=False)
+        self.model.eval()
+        self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset()))
+        self.block_data = BlockData()
+
+    def build_and_save_index(self):
+        i = 1
+        total = 0
+        while True:
+            with torch.no_grad():
+                try:
+                    query_tokens, query_pad_mask, \
+                        block_tokens, block_pad_mask, block_index_data = get_batch(self.dataloader)
+                except:
+                    break
+
+                block_index_data = detach(block_index_data)
+                block_indices = block_index_data[:, 3]
+                block_meta = block_index_data[:, :3]
+                block_logits = detach(self.model(None, None, block_tokens, block_pad_mask, only_block=True))
+
+                self.block_data.add_block_data(block_indices, block_logits, block_meta)
+                total += block_indices.size
+                i += 1
+                if i % 2000 == 0:
+                    print('Batch {:10d} | Total {:10d}'.format(i, total), flush=True)
+
+        self.block_data.save_shard(self.rank)
+        torch.distributed.barrier()
+        del self.model
+
+        if self.rank == 0:
+            self.block_data.consolidate_shards_and_save(ignore_shard=self.rank)
+        self.block_data.clear()
+
+
 def load_ict_checkpoint(only_query_model=False, only_block_model=False, no_grad=False, from_realm_chkpt=False):
     args = get_args()
     model = get_model(lambda: model_provider(only_query_model, only_block_model))
@@ -270,4 +310,8 @@ def get_one_epoch_dataloader(dataset):

 if __name__ == "__main__":
-    main()
+    initialize_megatron(extra_args_provider=None,
+                        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
+    index_builder = BasicIndexBuilder()
+    index_builder.build_and_save_index()
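A note on the save_shard / consolidate_shards_and_save / clear sequence used above: each builder rank writes only its own shard of the embedding table, a barrier ensures every shard is on disk, and a single rank merges them before memory is freed. The sketch below shows that generic pattern only; the file layout, pickle format, and function names are assumptions, not the BlockData implementation, and an initialized torch.distributed process group is assumed.

import os
import pickle
import torch.distributed as dist


def save_shard(shard_dir, rank, embed_data):
    # Each rank dumps only its own portion of the index.
    os.makedirs(shard_dir, exist_ok=True)
    with open(os.path.join(shard_dir, f"shard_{rank}.pkl"), "wb") as f:
        pickle.dump(embed_data, f)


def consolidate_shards(shard_dir, world_size, out_path):
    # One rank merges every shard into a single mapping after the barrier.
    merged = {}
    for r in range(world_size):
        with open(os.path.join(shard_dir, f"shard_{r}.pkl"), "rb") as f:
            merged.update(pickle.load(f))
    with open(out_path, "wb") as f:
        pickle.dump(merged, f)


def save_index(shard_dir, out_path, embed_data):
    rank = dist.get_rank()
    save_shard(shard_dir, rank, embed_data)
    dist.barrier()                      # wait until every shard is written
    if rank == 0:
        consolidate_shards(shard_dir, dist.get_world_size(), out_path)
    embed_data.clear()                  # free the in-memory copy on every rank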
@@ -24,6 +24,7 @@ import torch
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

 from megatron import mpu
+from megatron.mpu.initialize import get_train_group
 from megatron import get_args
 from megatron import print_rank_0
@@ -118,14 +119,14 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
        print(' successfully saved {}'.format(checkpoint_name))

     # Wait so everyone is done (necessary)
-    torch.distributed.barrier()
+    torch.distributed.barrier(get_train_group())

     # And update the latest iteration
     if torch.distributed.get_rank() == 0:
         tracker_filename = get_checkpoint_tracker_filename(args.save)
         with open(tracker_filename, 'w') as f:
             f.write(str(iteration))

     # Wait so everyone is done (not necessary)
-    torch.distributed.barrier()
+    torch.distributed.barrier(get_train_group())


 def load_checkpoint(model, optimizer, lr_scheduler):
...
@@ -163,8 +163,11 @@ class REALMRetriever(MegatronModule):
     def reload_index(self):
         args = get_args()
+        print("loading from file", flush=True)
         self.block_data = BlockData.load_from_file(args.block_data_path)
+        print("resetting index", flush=True)
         self.hashed_index.reset_index()
+        print("adding block data", flush=True)
         self.hashed_index.add_block_embed_data(self.block_data)

     def prep_query_text_for_retrieval(self, query_text):
...
@@ -373,13 +373,29 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
     timers('interval time').start()
     report_memory_flag = True

+    import time
+    print(">>> going to sleep", flush=True)
+    time.sleep(10)
+    print(">>> woke from sleep", flush=True)
+    print(time.ctime(time.time()), flush=True)
+
     global INDEX_READY
     recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, async_op=True)
+    print(">>>>>>>> Created recv handle", flush=True)

     while iteration < args.train_iters:
-        if hasattr(model, 'retriever') and INDEX_READY == 1:
-            model.retriever.reload_index()
+        print("INDEX READY: ", INDEX_READY)
+        if args.max_training_rank is not None and INDEX_READY == 1:
+            print(">>>>>>> entering the good stuff", flush=True)
+            true_model = model
+            if hasattr(true_model, 'module'):
+                true_model = true_model.module
+            if hasattr(true_model, 'module'):
+                true_model = true_model.module
+            print(">>>>>>> starting to reload index", flush=True)
+            true_model.retriever.reload_index()
+            print(">>>>>>> starting to save checkpoint", flush=True)
             save_checkpoint(iteration, model, optimizer, lr_scheduler)
+            print(">>>>>>> saved checkpoint", flush=True)

             if args.rank == 0:
                 INDEX_READY = 1 - INDEX_READY
@@ -387,6 +403,8 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
             send_handle = torch.distributed.broadcast(INDEX_READY, 0, async_op=True)
             torch.distributed.barrier(get_train_group())
             recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, async_op=True)
+        else:
+            print(">>>>>>> moving right along", flush=True)

         loss_dict, skipped_iter = train_step(forward_step_func,
                                              train_data_iterator,
...
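One note on the train() hunk above: it unwraps the (possibly DDP- and FP16-wrapped) model by hand with two nested hasattr checks before touching .retriever. A small loop handles arbitrary wrapper depth; this helper is a sketch, not part of the commit.

def unwrap_model(model):
    # Peel off wrapper layers (torchDDP, FP16 wrappers, etc.) until the module
    # that actually owns attributes like `.retriever` is reached.
    while hasattr(model, "module"):
        model = model.module
    return model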