"torchvision/vscode:/vscode.git/clone" did not exist on "e2e511be5918fc61008f384d88bce06bd61070da"
Commit 5684f904 authored by Neel Kant's avatar Neel Kant
Browse files

Full cycle of communication complete. Also added BasicIndexBuilder
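For context, the "full cycle" is a flag-toggling handshake over torch.distributed between a training group and an index-building group: the builder writes a new index, flips INDEX_READY, and waits; the trainer reloads the index, saves a fresh checkpoint, and flips the flag back. A minimal sketch of that cycle, mirroring the calls in the diff below (the one-element tensor wrapper and the group handles are assumptions; stock torch.distributed.broadcast operates on tensors, not Python ints):

```python
import torch
import torch.distributed as dist

# The flag travels as a one-element tensor here (assumption); the diff
# broadcasts the bare int INDEX_READY, which stock PyTorch would reject.
INDEX_READY = torch.zeros(1, dtype=torch.long)

def builder_signal(main_builder_rank, index_group):
    # Index-builder side: flip the flag, announce the new index on disk,
    # and keep all builder ranks in lockstep.
    INDEX_READY.fill_(1 - INDEX_READY.item())
    dist.broadcast(INDEX_READY, src=main_builder_rank, async_op=True)
    dist.barrier(group=index_group)
    # Post a receive for the trainer's acknowledgement (sent from rank 0).
    dist.broadcast(INDEX_READY, src=0, async_op=True)

def trainer_ack(main_builder_rank, train_group):
    # Trainer side: after reloading the index and saving a checkpoint,
    # flip the flag back so the builder leaves its wait loop.
    INDEX_READY.fill_(1 - INDEX_READY.item())
    dist.broadcast(INDEX_READY, src=0, async_op=True)
    dist.barrier(group=train_group)
    # Post a receive for the builder's next "index ready" signal.
    dist.broadcast(INDEX_READY, src=main_builder_rank, async_op=True)
```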

parent d4b00be0
@@ -110,15 +110,14 @@ class AsyncIndexBuilder(object):
     def run_async(self):
         while True:
-            print("Starting (again!)")
-            self.build_index()
-            self.save_index()
+            print("Starting (again!)", flush=True)
+            self.build_and_save_index()
             self.send_index_ready_signal()
             while INDEX_READY == 1:
-                print("Waiting for new model checkpoint.")
-                time.sleep(1)
+                print("Waiting for new model checkpoint.", flush=True)
+                time.sleep(5)
             self.load_model()
             self.load_attributes()

     def load_attributes(self):
         try:
@@ -129,7 +128,7 @@ class AsyncIndexBuilder(object):
         self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset()))
         self.block_data = BlockData()

-    def build_index(self):
+    def build_and_save_index(self):
         i = 1
         total = 0
         while True:
@@ -149,7 +148,7 @@ class AsyncIndexBuilder(object):
             total += block_indices.size
             i += 1
-            if i % 2000 == 0:
+            if i % 10 == 0:
                 print('Batch {:10d} | Total {:10d}'.format(i, total), flush=True)
             if self.debug:
                 break
@@ -162,27 +161,68 @@ class AsyncIndexBuilder(object):
             print_rank_0(">>> training terminated. Returning")
             sys.exit(0)

-    def save_index(self):
         self.block_data.save_shard(self.rank)
         torch.distributed.barrier()
         del self.model

         if self.is_main_builder:
             self.block_data.consolidate_shards_and_save(ignore_shard=self.rank)
-        else:
-            self.block_data.clear()
+        self.block_data.clear()

     def send_index_ready_signal(self):
         global INDEX_READY
         if self.is_main_builder:
             INDEX_READY = 1 - INDEX_READY
+            print("Switched INDEX_READY", flush=True)
+            import time
+            print(time.ctime(time.time()), flush=True)
             send_handle = dist.broadcast(INDEX_READY, self.main_builder_idx, async_op=True)
         torch.distributed.barrier(get_index_group())
         recv_handle = dist.broadcast(INDEX_READY, 0, async_op=True)
+
+
+class BasicIndexBuilder(object):
+    def __init__(self):
+        args = get_args()
+        self.rank = args.rank
+        self.model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=False)
+        self.model.eval()
+        self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset()))
+        self.block_data = BlockData()
+
+    def build_and_save_index(self):
+        i = 1
+        total = 0
+        while True:
+            with torch.no_grad():
+                try:
+                    query_tokens, query_pad_mask, \
+                        block_tokens, block_pad_mask, block_index_data = get_batch(self.dataloader)
+                except:
+                    break
+
+                block_index_data = detach(block_index_data)
+                block_indices = block_index_data[:, 3]
+                block_meta = block_index_data[:, :3]
+                block_logits = detach(self.model(None, None, block_tokens, block_pad_mask, only_block=True))
+                self.block_data.add_block_data(block_indices, block_logits, block_meta)
+
+                total += block_indices.size
+                i += 1
+                if i % 2000 == 0:
+                    print('Batch {:10d} | Total {:10d}'.format(i, total), flush=True)
+
+        self.block_data.save_shard(self.rank)
+        torch.distributed.barrier()
+        del self.model
+
+        if self.rank == 0:
+            self.block_data.consolidate_shards_and_save(ignore_shard=self.rank)
+        self.block_data.clear()
@@ -270,4 +310,8 @@ def get_one_epoch_dataloader(dataset):

 if __name__ == "__main__":
-    main()
+    initialize_megatron(extra_args_provider=None,
+                        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
+
+    index_builder = BasicIndexBuilder()
+    index_builder.build_and_save_index()
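BasicIndexBuilder ends by writing one embedding shard per rank and letting rank 0 merge them. BlockData's on-disk format is not shown in this diff; a minimal sketch of the save-shard/consolidate pattern it implies, with a hypothetical pickle layout (shard filenames, directory, and the dict-of-embeddings format are all assumptions):

```python
import glob
import os
import pickle

def save_shard(embed_data, rank, dirname="block_data_shards"):
    # Each rank writes its own shard; no coordination needed beyond the
    # torch.distributed.barrier() the diff places after this call.
    os.makedirs(dirname, exist_ok=True)
    with open(os.path.join(dirname, f"shard_{rank}.pkl"), "wb") as f:
        pickle.dump(embed_data, f)

def consolidate_shards_and_save(out_path, dirname="block_data_shards"):
    # Rank 0 merges every shard into one file for the retriever to load.
    merged = {}
    for path in sorted(glob.glob(os.path.join(dirname, "shard_*.pkl"))):
        with open(path, "rb") as f:
            merged.update(pickle.load(f))  # assumes dict: block index -> embedding
    with open(out_path, "wb") as f:
        pickle.dump(merged, f)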
@@ -24,6 +24,7 @@ import torch
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

 from megatron import mpu
+from megatron.mpu.initialize import get_train_group
 from megatron import get_args
 from megatron import print_rank_0
@@ -118,14 +119,14 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
         print(' successfully saved {}'.format(checkpoint_name))

     # Wait so everyone is done (necessary)
-    torch.distributed.barrier()
+    torch.distributed.barrier(get_train_group())

     # And update the latest iteration
     if torch.distributed.get_rank() == 0:
         tracker_filename = get_checkpoint_tracker_filename(args.save)
         with open(tracker_filename, 'w') as f:
             f.write(str(iteration))

     # Wait so everyone is done (not necessary)
-    torch.distributed.barrier()
+    torch.distributed.barrier(get_train_group())
def load_checkpoint(model, optimizer, lr_scheduler):
......
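The swap from a global barrier to `get_train_group()` matters because the index-builder ranks never enter save_checkpoint: a world-wide barrier would hang forever waiting on them. A sketch of the subgroup setup this presumes (the rank split mirrors args.max_training_rank used later in the diff, but the helper below is hypothetical):

```python
import torch.distributed as dist

def build_process_groups(world_size, max_training_rank):
    # Hypothetical split: ranks [0, max_training_rank) train,
    # ranks [max_training_rank, world_size) build the index.
    train_group = dist.new_group(ranks=list(range(max_training_rank)))
    index_group = dist.new_group(ranks=list(range(max_training_rank, world_size)))
    return train_group, index_group

# save_checkpoint only runs on training ranks, so its barrier must be scoped:
#     dist.barrier(group=train_group)
# A bare dist.barrier() would deadlock, blocking on index-builder ranks
# that never call save_checkpoint.
```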
@@ -163,8 +163,11 @@ class REALMRetriever(MegatronModule):
     def reload_index(self):
         args = get_args()
+        print("loading from file", flush=True)
         self.block_data = BlockData.load_from_file(args.block_data_path)
+        print("resetting index", flush=True)
         self.hashed_index.reset_index()
+        print("adding block data", flush=True)
         self.hashed_index.add_block_embed_data(self.block_data)

     def prep_query_text_for_retrieval(self, query_text):
......
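The reset-then-add sequence in reload_index swaps the stale embeddings for the newly consolidated ones without reconstructing the index object. If hashed_index wraps something FAISS-like (an assumption; its implementation is outside this diff), the pattern looks like:

```python
import numpy as np
import faiss  # assumption: hashed_index wraps a FAISS index

def rebuild(index, block_embeds):
    # block_embeds: (num_blocks, dim) float32 matrix of block embeddings
    index.reset()                                  # drop the stale vectors
    index.add(np.ascontiguousarray(block_embeds))  # load the fresh ones

dim = 128
index = faiss.IndexFlatIP(dim)  # inner-product search, as REALM scores query . block
rebuild(index, np.random.rand(1000, dim).astype(np.float32))
print(index.ntotal)             # 1000
```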
@@ -373,13 +373,29 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
     timers('interval time').start()
     report_memory_flag = True

+    import time
+    print(">>> going to sleep", flush=True)
+    time.sleep(10)
+    print(">>> woke from sleep", flush=True)
+    print(time.ctime(time.time()), flush=True)
+
+    global INDEX_READY
+    recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, async_op=True)
+    print(">>>>>>>> Created recv handle", flush=True)
+
     while iteration < args.train_iters:
-        if hasattr(model, 'retriever') and INDEX_READY == 1:
-            model.retriever.reload_index()
+        print("INDEX READY: ", INDEX_READY)
+        if args.max_training_rank is not None and INDEX_READY == 1:
+            print(">>>>>>> entering the good stuff", flush=True)
+            true_model = model
+            if hasattr(true_model, 'module'):
+                true_model = true_model.module
+            if hasattr(true_model, 'module'):
+                true_model = true_model.module
+            print(">>>>>>> starting to reload index", flush=True)
+            true_model.retriever.reload_index()
+            print(">>>>>>> starting to save checkpoint", flush=True)
+            save_checkpoint(iteration, model, optimizer, lr_scheduler)
+            print(">>>>>>> saved checkpoint", flush=True)
+            if args.rank == 0:
+                INDEX_READY = 1 - INDEX_READY
@@ -387,6 +403,8 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
+                send_handle = torch.distributed.broadcast(INDEX_READY, 0, async_op=True)
+            torch.distributed.barrier(get_train_group())
+            recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, async_op=True)
+        else:
+            print(">>>>>>> moving right along", flush=True)

         loss_dict, skipped_iter = train_step(forward_step_func,
                                              train_data_iterator,
......
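The two chained hasattr checks in the train loop unwrap exactly two wrapper layers (e.g., torchDDP around an FP16 module) to reach the object that owns `.retriever`. A depth-independent version of the same unwrapping, as a sketch:

```python
def unwrap_model(model):
    # Peel off wrapper modules (DistributedDataParallel, FP16 wrappers, ...)
    # until the underlying model that owns .retriever is reached.
    while hasattr(model, 'module'):
        model = model.module
    return model

# true_model = unwrap_model(model); true_model.retriever.reload_index()
```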