"docs/references/setup_github_runner.md" did not exist on "d738ab52f86f65806fc6549270a704e34aae5a32"
Commit 15d0d55b authored by Neel Kant's avatar Neel Kant
Browse files

Add primitive filesystem-based IPC for indexer and trainer jobs

parent 0f5e2809
......@@ -67,7 +67,6 @@ def print_accuracy_stats(name, gold_indices, estimated_indices):
print('{:20s} First missing: {:4d} | All equal: {:4d} | Mixed: {:4d}'.format(name, *[results[s] for s in result_strs]))
def create_and_test_gold(d, k, embeds, queries):
times = [time.time()]
gold_idx = index_factory(d, 'Flat')
......
import os
import time
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
......@@ -15,6 +18,7 @@ from pretrain_bert_ict import get_batch, model_provider
def test_retriever():
# TODO: Update this because it's outdated and definitely won't run.
initialize_megatron(extra_args_provider=None,
args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
args = get_args()
......@@ -57,75 +61,139 @@ def main():
initialize_megatron(extra_args_provider=None,
args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
args = get_args()
model = load_ict_checkpoint(only_block_model=True, no_grad=True)
model.eval()
dataset = get_ict_dataset()
data_iter = iter(get_one_epoch_dataloader(dataset))
all_block_data = BlockData()
hashed_index = RandProjectionLSHIndex(embed_size=128, num_buckets=32, whiten=True)
i = 1
total = 0
ran_once = False
while True:
with torch.no_grad():
try:
query_tokens, query_pad_mask, \
block_tokens, block_pad_mask, block_index_data = get_batch(data_iter)
except:
break
block_index_data = detach(block_index_data)
block_indices = block_index_data[:, 3]
block_meta = block_index_data[:, :3]
block_logits = detach(model(None, None, block_tokens, block_pad_mask, only_block=True))
all_block_data.add_block_data(block_indices, block_logits, block_meta)
total += block_indices.size
i += 1
if i % 20 == 0:
print('Batch {:10d} | Total {:10d}'.format(i, total), flush=True)
if args.debug:
model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=ran_once)
model.eval()
dataset = get_ict_dataset()
data_iter = iter(get_one_epoch_dataloader(dataset))
all_block_data = BlockData()
hashed_index = RandProjectionLSHIndex(embed_size=128, num_buckets=32, whiten=True)
i = 1
total = 0
while True:
with torch.no_grad():
try:
query_tokens, query_pad_mask, \
block_tokens, block_pad_mask, block_index_data = get_batch(data_iter)
except:
break
all_block_data.save_shard(args.rank)
torch.distributed.barrier()
del model
block_index_data = detach(block_index_data)
block_indices = block_index_data[:, 3]
block_meta = block_index_data[:, :3]
block_logits = detach(model(None, None, block_tokens, block_pad_mask, only_block=True))
all_block_data.add_block_data(block_indices, block_logits, block_meta)
total += block_indices.size
i += 1
if i % 20 == 0:
print('Batch {:10d} | Total {:10d}'.format(i, total), flush=True)
if args.debug:
break
all_block_data.save_shard(args.rank)
torch.distributed.barrier()
del model
if args.rank == 0:
all_block_data.consolidate_shards_and_save()
hashed_index.hash_whitened_block_embeds(all_block_data)
hashed_index.save_to_file()
else:
all_block_data.clear()
ran_once = True
set_index_com_file_ready()
torch.distributed.barrier()
while not check_model_com_file_ready():
time.sleep(5)
set_model_com_file_not_ready()
# Sentinel files for the primitive filesystem-based IPC between the indexer
# job and the trainer job. Each file contains a single character flag:
# '0' = not ready, '1' = ready (see the set_*/check_* helpers below).
INDEX_COM_FILE = 'ready.index'
MODEL_COM_FILE = 'ready.model'
def setup_index_com_file():
    """Initialize the index communication file to the 'not ready' state."""
    set_index_com_file_not_ready()
if args.rank == 0:
all_block_data.consolidate_shards_and_save()
hashed_index.hash_whitened_block_embeds(all_block_data)
hashed_index.save_to_file()
else:
all_block_data.clear()
def set_index_com_file_not_ready():
    """Signal 'index not ready' by writing '0' to the index sentinel file."""
    flag_file = open(INDEX_COM_FILE, 'w')
    try:
        flag_file.write('0')
    finally:
        flag_file.close()
def load_ict_checkpoint(only_query_model=False, only_block_model=False, no_grad=False):
def set_index_com_file_ready():
    """Signal 'index ready' by writing '1' to the index sentinel file."""
    flag_file = open(INDEX_COM_FILE, 'w')
    try:
        flag_file.write('1')
    finally:
        flag_file.close()
def check_index_com_file_ready(path=None):
    """Return True iff the index sentinel file exists and signals 'ready'.

    The sentinel file holds '0' (not ready) or '1' (ready).

    Args:
        path: optional override of the sentinel file location; defaults to
            the module-level INDEX_COM_FILE.

    Bug fix: the previous implementation returned ``bool(com_file.readline())``,
    which is True for ANY non-empty content -- including the '0' written by
    set_index_com_file_not_ready() -- so "not ready" was misreported as ready.
    """
    if path is None:
        path = INDEX_COM_FILE
    if os.path.exists(path):
        with open(path, 'r') as com_file:
            # strip() guards against a trailing newline if a writer changes.
            return com_file.readline().strip() == '1'
    return False
def setup_model_com_file():
    """Initialize the model communication file to the 'not ready' state."""
    set_model_com_file_not_ready()
def set_model_com_file_not_ready():
    """Signal 'model not ready' by writing '0' to the model sentinel file."""
    flag_file = open(MODEL_COM_FILE, 'w')
    try:
        flag_file.write('0')
    finally:
        flag_file.close()
def set_model_com_file_ready():
    """Signal 'model ready' by writing '1' to the model sentinel file."""
    flag_file = open(MODEL_COM_FILE, 'w')
    try:
        flag_file.write('1')
    finally:
        flag_file.close()
def check_model_com_file_ready(path=None):
    """Return True iff the model sentinel file exists and signals 'ready'.

    The sentinel file holds '0' (not ready) or '1' (ready).

    Args:
        path: optional override of the sentinel file location; defaults to
            the module-level MODEL_COM_FILE.

    Bug fix: the previous implementation returned ``bool(com_file.readline())``,
    which is True for ANY non-empty content -- including the '0' written by
    set_model_com_file_not_ready() -- so "not ready" was misreported as ready.
    """
    if path is None:
        path = MODEL_COM_FILE
    if os.path.exists(path):
        with open(path, 'r') as com_file:
            # strip() guards against a trailing newline if a writer changes.
            return com_file.readline().strip() == '1'
    return False
def load_ict_checkpoint(only_query_model=False, only_block_model=False, no_grad=False, from_realm_chkpt=False):
args = get_args()
model = get_model(lambda: model_provider(only_query_model, only_block_model))
load_path = args.load if from_realm_chkpt else args.ict_load
if isinstance(model, torchDDP):
model = model.module
tracker_filename = get_checkpoint_tracker_filename(args.ict_load)
tracker_filename = get_checkpoint_tracker_filename(load_path)
with open(tracker_filename, 'r') as f:
iteration = int(f.read().strip())
assert iteration > 0
checkpoint_name = get_checkpoint_name(args.ict_load, iteration, False)
checkpoint_name = get_checkpoint_name(load_path, iteration, False)
if mpu.get_data_parallel_rank() == 0:
print('global rank {} is loading checkpoint {}'.format(
torch.distributed.get_rank(), checkpoint_name))
state_dict = torch.load(checkpoint_name, map_location='cpu')
ict_state_dict = state_dict['model']
if from_realm_chkpt:
ict_state_dict = ict_state_dict['retriever']['ict_model']
if only_query_model:
state_dict['model'].pop('context_model')
ict_state_dict.pop('context_model')
if only_block_model:
state_dict['model'].pop('question_model')
ict_state_dict.pop('question_model')
if no_grad:
with torch.no_grad():
model.load_state_dict(state_dict['model'])
model.load_state_dict(ict_state_dict)
else:
model.load_state_dict(state_dict['model'])
model.load_state_dict(ict_state_dict)
torch.distributed.barrier()
if mpu.get_data_parallel_rank() == 0:
......
......@@ -86,7 +86,8 @@ class FaissMIPSIndex(object):
self.m = 5
self.u = 0.99
self.max_norm = None
self.block_mips_index = self.get_block_index()
self.block_mips_index = None
self._set_block_index()
@classmethod
def load_from_file(cls, fname):
......@@ -101,7 +102,7 @@ class FaissMIPSIndex(object):
return new_index
def get_block_index(self):
def _set_block_index(self):
import faiss
INDEX_TYPES = ['flat_l2', 'flat_ip']
if self.index_type not in INDEX_TYPES:
......@@ -109,10 +110,13 @@ class FaissMIPSIndex(object):
if self.index_type == 'flat_l2':
index = faiss.IndexFlatL2(self.embed_size + 2 * self.m)
return faiss.IndexIDMap(index)
self.block_mips_index = faiss.IndexIDMap(index)
elif self.index_type == 'flat_ip':
index = faiss.IndexFlatIP(self.embed_size)
return faiss.IndexIDMap(index)
self.block_mips_index = faiss.IndexIDMap(index)
def reset_index(self):
    """Rebuild the underlying FAISS block index from scratch.

    Delegates to _set_block_index(), which replaces self.block_mips_index
    with a fresh, empty IndexIDMap -- discarding all previously added
    block embeddings.
    """
    self._set_block_index()
def add_block_embed_data(self, all_block_data, clear_block_data=False):
"""Add the embedding of each block to the underlying FAISS index"""
......
......@@ -4,7 +4,7 @@ import torch.nn.functional as F
from megatron import get_args
from megatron.checkpointing import load_checkpoint
from megatron.data.realm_index import detach
from megatron.data.realm_index import detach, BlockData, FaissMIPSIndex
from megatron.model import BertModel
from megatron.model.utils import get_linear_layer, init_method_normal
from megatron.module import MegatronModule
......@@ -161,6 +161,12 @@ class REALMRetriever(MegatronModule):
self.top_k = top_k
self._ict_key = 'ict_model'
def reload_index(self):
    """Reload the block embedding data from disk and rebuild the index.

    Used by the trainer to pick up embeddings freshly written by the
    indexer job: loads BlockData from args.block_data_path, resets the
    hashed index, and re-adds all block embeddings to it.
    """
    block_data_path = get_args().block_data_path
    self.block_data = BlockData.load_from_file(block_data_path)
    self.hashed_index.reset_index()
    self.hashed_index.add_block_embed_data(self.block_data)
def retrieve_evidence_blocks_text(self, query_text):
"""Get the top k evidence blocks for query_text in text form"""
print("-" * 100)
......@@ -256,7 +262,6 @@ class ICTBertModel(MegatronModule):
if only_block:
return self.embed_block(block_tokens, block_attention_mask)
query_logits = self.embed_query(query_tokens, query_attention_mask)
block_logits = self.embed_block(block_tokens, block_attention_mask)
......
......@@ -39,6 +39,7 @@ from megatron.model import get_params_for_weight_decay_optimization
from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import make_data_loader
from megatron.utils import report_memory
from indexer import check_index_com_file_ready, set_index_com_file_not_ready, set_model_com_file_ready
def pretrain(train_valid_test_dataset_provider, model_provider,
......@@ -363,6 +364,15 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
timers('interval time').start()
report_memory_flag = True
while iteration < args.train_iters:
if hasattr(model, 'retriever'):
new_index_ready = check_index_com_file_ready()
if new_index_ready:
torch.distributed.barrier()
model.retriever.reload_index()
set_index_com_file_not_ready()
save_checkpoint(iteration, model, optimizer, lr_scheduler)
set_model_com_file_ready()
loss_dict, skipped_iter = train_step(forward_step_func,
train_data_iterator,
model,
......
......@@ -18,7 +18,7 @@
import torch
import torch.nn.functional as F
from hashed_index import load_ict_checkpoint, get_ict_dataset
from indexer import load_ict_checkpoint, get_ict_dataset
from megatron.data.realm_index import BlockData, RandProjectionLSHIndex, FaissMIPSIndex
from megatron import get_args
from megatron import get_timers
......@@ -41,7 +41,7 @@ def model_provider():
ict_dataset = get_ict_dataset(use_titles=False)
all_block_data = BlockData.load_from_file(args.block_data_path)
# hashed_index = RandProjectionLSHIndex.load_from_file(args.block_index_path)
hashed_index = FaissMIPSIndex(index_type='flat_l2', embed_size=128)
hashed_index = FaissMIPSIndex(index_type='flat_ip', embed_size=128)
hashed_index.add_block_embed_data(all_block_data)
# top_k + 1 because we may need to exclude trivial candidate
......@@ -102,7 +102,6 @@ def forward_step(data_iterator, model):
reduced_loss = reduce_losses([lm_loss])
torch.cuda.synchronize()
print(reduced_loss, flush=True)
return lm_loss, {'lm_loss': reduced_loss[0]}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment