Commit 15d0d55b authored by Neel Kant

Add primitive filesystem-based IPC for indexer and trainer jobs

The indexer job embeds every block with the latest checkpoint, saves the
index, then writes a 'ready.index' flag file. The trainer job polls that
flag each iteration; when it flips, the trainer reloads the block index,
saves a fresh checkpoint, and acknowledges by writing 'ready.model',
which the indexer polls before starting its next pass.

parent 0f5e2809
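
The same handshake, condensed into one self-contained sketch before the per-file hunks. Flag-file names and the 5-second poll match the diff; the indexer_cycle/trainer_cycle wrappers are illustrative stand-ins, not code from the commit:

# Sketch of the handshake added in this commit, condensed into one file.
import os
import time

INDEX_COM_FILE = 'ready.index'
MODEL_COM_FILE = 'ready.model'


def set_flag(path, ready):
    with open(path, 'w') as com_file:
        com_file.write('1' if ready else '0')


def flag_ready(path):
    if not os.path.exists(path):
        return False
    with open(path, 'r') as com_file:
        return com_file.readline().strip() == '1'


def indexer_cycle():
    # ... embed every block with the latest checkpoint, save the index ...
    set_flag(INDEX_COM_FILE, True)         # publish: a new index exists
    while not flag_ready(MODEL_COM_FILE):  # block until the trainer answers
        time.sleep(5)
    set_flag(MODEL_COM_FILE, False)        # consume the trainer's flag


def trainer_cycle(save_checkpoint):
    # Called once per training iteration.
    if flag_ready(INDEX_COM_FILE):
        # ... reload the block index into the retriever ...
        set_flag(INDEX_COM_FILE, False)    # consume the indexer's flag
        save_checkpoint()                  # give the indexer fresh weights
        set_flag(MODEL_COM_FILE, True)     # publish: a new checkpoint exists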
@@ -67,7 +67,6 @@ def print_accuracy_stats(name, gold_indices, estimated_indices):
     print('{:20s} First missing: {:4d} | All equal: {:4d} | Mixed: {:4d}'.format(name, *[results[s] for s in result_strs]))


 def create_and_test_gold(d, k, embeds, queries):
     times = [time.time()]
     gold_idx = index_factory(d, 'Flat')
...
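
For context on the hunk above: index_factory(d, 'Flat') builds an exhaustive index, so its neighbors serve as ground truth when scoring an approximate index. A minimal, self-contained version of that gold-index pattern (random data; shapes are arbitrary):

import numpy as np
from faiss import index_factory

d, k, n = 128, 5, 1000
embeds = np.random.rand(n, d).astype('float32')
queries = np.random.rand(10, d).astype('float32')

gold_idx = index_factory(d, 'Flat')  # exact (brute-force) L2 index
gold_idx.add(embeds)
_, gold_indices = gold_idx.search(queries, k)  # (10, k) exact neighbor ids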
+import os
+import time
 import torch

 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
@@ -15,6 +18,7 @@ from pretrain_bert_ict import get_batch, model_provider


 def test_retriever():
+    # TODO: Update this because it's outdated and definitely won't run.
     initialize_megatron(extra_args_provider=None,
                         args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
     args = get_args()
@@ -57,75 +61,139 @@ def main():
     initialize_megatron(extra_args_provider=None,
                         args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
     args = get_args()
-    model = load_ict_checkpoint(only_block_model=True, no_grad=True)
-    model.eval()
-    dataset = get_ict_dataset()
-    data_iter = iter(get_one_epoch_dataloader(dataset))
-    all_block_data = BlockData()
-    hashed_index = RandProjectionLSHIndex(embed_size=128, num_buckets=32, whiten=True)
-    i = 1
-    total = 0
+    ran_once = False
     while True:
-        with torch.no_grad():
-            try:
-                query_tokens, query_pad_mask, \
-                    block_tokens, block_pad_mask, block_index_data = get_batch(data_iter)
-            except:
-                break
-
-            block_index_data = detach(block_index_data)
-            block_indices = block_index_data[:, 3]
-            block_meta = block_index_data[:, :3]
-
-            block_logits = detach(model(None, None, block_tokens, block_pad_mask, only_block=True))
-            all_block_data.add_block_data(block_indices, block_logits, block_meta)
-
-            total += block_indices.size
-            i += 1
-            if i % 20 == 0:
-                print('Batch {:10d} | Total {:10d}'.format(i, total), flush=True)
-            if args.debug:
-                break
-
-    all_block_data.save_shard(args.rank)
-    torch.distributed.barrier()
-    del model
-
-    if args.rank == 0:
-        all_block_data.consolidate_shards_and_save()
-        hashed_index.hash_whitened_block_embeds(all_block_data)
-        hashed_index.save_to_file()
-    else:
-        all_block_data.clear()
+        model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=ran_once)
+        model.eval()
+        dataset = get_ict_dataset()
+        data_iter = iter(get_one_epoch_dataloader(dataset))
+        all_block_data = BlockData()
+        hashed_index = RandProjectionLSHIndex(embed_size=128, num_buckets=32, whiten=True)
+        i = 1
+        total = 0
+        while True:
+            with torch.no_grad():
+                try:
+                    query_tokens, query_pad_mask, \
+                        block_tokens, block_pad_mask, block_index_data = get_batch(data_iter)
+                except:
+                    break
+
+                block_index_data = detach(block_index_data)
+                block_indices = block_index_data[:, 3]
+                block_meta = block_index_data[:, :3]
+
+                block_logits = detach(model(None, None, block_tokens, block_pad_mask, only_block=True))
+                all_block_data.add_block_data(block_indices, block_logits, block_meta)
+
+                total += block_indices.size
+                i += 1
+                if i % 20 == 0:
+                    print('Batch {:10d} | Total {:10d}'.format(i, total), flush=True)
+                if args.debug:
+                    break
+
+        all_block_data.save_shard(args.rank)
+        torch.distributed.barrier()
+        del model
+
+        if args.rank == 0:
+            all_block_data.consolidate_shards_and_save()
+            hashed_index.hash_whitened_block_embeds(all_block_data)
+            hashed_index.save_to_file()
+        else:
+            all_block_data.clear()
+
+        ran_once = True
+        set_index_com_file_ready()
+        torch.distributed.barrier()
+        while not check_model_com_file_ready():
+            time.sleep(5)
+        set_model_com_file_not_ready()


+INDEX_COM_FILE = 'ready.index'
+MODEL_COM_FILE = 'ready.model'


+def setup_index_com_file():
+    set_index_com_file_not_ready()


+def set_index_com_file_not_ready():
+    with open(INDEX_COM_FILE, 'w') as com_file:
+        com_file.write('0')


+def set_index_com_file_ready():
+    with open(INDEX_COM_FILE, 'w') as com_file:
+        com_file.write('1')


+def check_index_com_file_ready():
+    if os.path.exists(INDEX_COM_FILE):
+        with open(INDEX_COM_FILE, 'r') as com_file:
+            # Compare against '1': bool(com_file.readline()) is truthy for
+            # '0' too, which would make the flag read as always-ready.
+            return com_file.readline().strip() == '1'
+    return False


+def setup_model_com_file():
+    set_model_com_file_not_ready()


+def set_model_com_file_not_ready():
+    with open(MODEL_COM_FILE, 'w') as com_file:
+        com_file.write('0')


+def set_model_com_file_ready():
+    with open(MODEL_COM_FILE, 'w') as com_file:
+        com_file.write('1')


+def check_model_com_file_ready():
+    if os.path.exists(MODEL_COM_FILE):
+        with open(MODEL_COM_FILE, 'r') as com_file:
+            return com_file.readline().strip() == '1'
+    return False


-def load_ict_checkpoint(only_query_model=False, only_block_model=False, no_grad=False):
+def load_ict_checkpoint(only_query_model=False, only_block_model=False, no_grad=False, from_realm_chkpt=False):
     args = get_args()
     model = get_model(lambda: model_provider(only_query_model, only_block_model))
+    load_path = args.load if from_realm_chkpt else args.ict_load

     if isinstance(model, torchDDP):
         model = model.module
-    tracker_filename = get_checkpoint_tracker_filename(args.ict_load)
+    tracker_filename = get_checkpoint_tracker_filename(load_path)
     with open(tracker_filename, 'r') as f:
         iteration = int(f.read().strip())
     assert iteration > 0

-    checkpoint_name = get_checkpoint_name(args.ict_load, iteration, False)
+    checkpoint_name = get_checkpoint_name(load_path, iteration, False)
     if mpu.get_data_parallel_rank() == 0:
         print('global rank {} is loading checkpoint {}'.format(
             torch.distributed.get_rank(), checkpoint_name))

     state_dict = torch.load(checkpoint_name, map_location='cpu')
+    ict_state_dict = state_dict['model']
+    if from_realm_chkpt:
+        ict_state_dict = ict_state_dict['retriever']['ict_model']

     if only_query_model:
-        state_dict['model'].pop('context_model')
+        ict_state_dict.pop('context_model')
     if only_block_model:
-        state_dict['model'].pop('question_model')
+        ict_state_dict.pop('question_model')

     if no_grad:
         with torch.no_grad():
-            model.load_state_dict(state_dict['model'])
+            model.load_state_dict(ict_state_dict)
     else:
-        model.load_state_dict(state_dict['model'])
+        model.load_state_dict(ict_state_dict)

     torch.distributed.barrier()
     if mpu.get_data_parallel_rank() == 0:
...
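
The new from_realm_chkpt flag selects between two checkpoint layouts. Inferred from the pops and nested lookups above (key names from the diff; the layout beyond them is an assumption), the shapes are roughly:

# Assumed checkpoint layouts, inferred from this diff; illustrative only.
ict_checkpoint = {
    'model': {
        'question_model': {},  # query-encoder BERT weights
        'context_model': {},   # block-encoder BERT weights
    },
}
realm_checkpoint = {
    'model': {
        'retriever': {
            'ict_model': {     # same two sub-models, nested one level down
                'question_model': {},
                'context_model': {},
            },
        },
    },
}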
@@ -86,7 +86,8 @@ class FaissMIPSIndex(object):
         self.m = 5
         self.u = 0.99
         self.max_norm = None
-        self.block_mips_index = self.get_block_index()
+        self.block_mips_index = None
+        self._set_block_index()

     @classmethod
     def load_from_file(cls, fname):
@@ -101,7 +102,7 @@ class FaissMIPSIndex(object):
         return new_index

-    def get_block_index(self):
+    def _set_block_index(self):
         import faiss
         INDEX_TYPES = ['flat_l2', 'flat_ip']
         if self.index_type not in INDEX_TYPES:
@@ -109,10 +110,13 @@ class FaissMIPSIndex(object):
         if self.index_type == 'flat_l2':
             index = faiss.IndexFlatL2(self.embed_size + 2 * self.m)
-            return faiss.IndexIDMap(index)
+            self.block_mips_index = faiss.IndexIDMap(index)
         elif self.index_type == 'flat_ip':
             index = faiss.IndexFlatIP(self.embed_size)
-            return faiss.IndexIDMap(index)
+            self.block_mips_index = faiss.IndexIDMap(index)
+
+    def reset_index(self):
+        self._set_block_index()

     def add_block_embed_data(self, all_block_data, clear_block_data=False):
         """Add the embedding of each block to the underlying FAISS index"""
...
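
A note on the IndexIDMap wrapper used in _set_block_index: it lets add_with_ids attach arbitrary block ids to the stored vectors, so searches return block ids rather than insertion-order row numbers. A standalone illustration (sizes and data are arbitrary):

import numpy as np
import faiss

embed_size = 128
index = faiss.IndexIDMap(faiss.IndexFlatIP(embed_size))

embeds = np.random.rand(100, embed_size).astype('float32')
block_ids = np.arange(1000, 1100).astype('int64')  # non-contiguous ids are fine
index.add_with_ids(embeds, block_ids)

query = np.random.rand(1, embed_size).astype('float32')
scores, ids = index.search(query, 5)  # ids are drawn from block_ids, not 0..99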
@@ -4,7 +4,7 @@ import torch.nn.functional as F

 from megatron import get_args
 from megatron.checkpointing import load_checkpoint
-from megatron.data.realm_index import detach
+from megatron.data.realm_index import detach, BlockData, FaissMIPSIndex
 from megatron.model import BertModel
 from megatron.model.utils import get_linear_layer, init_method_normal
 from megatron.module import MegatronModule
@@ -161,6 +161,12 @@ class REALMRetriever(MegatronModule):
         self.top_k = top_k
         self._ict_key = 'ict_model'

+    def reload_index(self):
+        args = get_args()
+        self.block_data = BlockData.load_from_file(args.block_data_path)
+        self.hashed_index.reset_index()
+        self.hashed_index.add_block_embed_data(self.block_data)
+
     def retrieve_evidence_blocks_text(self, query_text):
         """Get the top k evidence blocks for query_text in text form"""
         print("-" * 100)
@@ -256,7 +262,6 @@ class ICTBertModel(MegatronModule):
         if only_block:
             return self.embed_block(block_tokens, block_attention_mask)

         query_logits = self.embed_query(query_tokens, query_attention_mask)
         block_logits = self.embed_block(block_tokens, block_attention_mask)
...
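
reload_index refreshes the retriever in place: reset_index rebuilds the empty FAISS structure on the existing hashed_index object, so anything holding a reference to the retriever (the REALM model, the training loop) sees the new blocks without being reconstructed. A minimal sketch of that reset-then-refill pattern (class and data are hypothetical):

import numpy as np
import faiss


class InPlaceIndex:
    """Sketch: a MIPS index that can be refilled without changing identity."""

    def __init__(self, embed_size):
        self.embed_size = embed_size
        self._build()

    def _build(self):
        self.index = faiss.IndexIDMap(faiss.IndexFlatIP(self.embed_size))

    def reset_index(self):
        # Rebuild the empty index in place; callers keep their reference.
        self._build()

    def reload(self, embeds, ids):
        self.reset_index()
        self.index.add_with_ids(embeds, ids)


idx = InPlaceIndex(embed_size=8)
idx.reload(np.random.rand(10, 8).astype('float32'),
           np.arange(10).astype('int64'))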
@@ -39,6 +39,7 @@ from megatron.model import get_params_for_weight_decay_optimization
 from megatron.utils import check_adlr_autoresume_termination
 from megatron.utils import make_data_loader
 from megatron.utils import report_memory
+from indexer import check_index_com_file_ready, set_index_com_file_not_ready, set_model_com_file_ready


 def pretrain(train_valid_test_dataset_provider, model_provider,
@@ -363,6 +364,15 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
     timers('interval time').start()
     report_memory_flag = True
     while iteration < args.train_iters:
+        if hasattr(model, 'retriever'):
+            new_index_ready = check_index_com_file_ready()
+            if new_index_ready:
+                torch.distributed.barrier()
+                model.retriever.reload_index()
+                set_index_com_file_not_ready()
+                save_checkpoint(iteration, model, optimizer, lr_scheduler)
+                set_model_com_file_ready()
+
         loss_dict, skipped_iter = train_step(forward_step_func,
                                              train_data_iterator,
                                              model,
...
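
The rendezvous order above is what keeps the two jobs from deadlocking or double-consuming a flag: the trainer consumes ready.index before publishing ready.model, and the indexer consumes ready.model before its next publish. A toy single-process simulation of that ordering (threads stand in for the two jobs; all names are illustrative):

import os
import threading
import time


def set_flag(path, ready):
    with open(path, 'w') as com_file:
        com_file.write('1' if ready else '0')


def flag_ready(path):
    if not os.path.exists(path):
        return False
    with open(path, 'r') as com_file:
        return com_file.readline().strip() == '1'


def indexer(rounds):
    for r in range(rounds):
        print('indexer: built index %d' % r)
        set_flag('ready.index', True)
        while not flag_ready('ready.model'):
            time.sleep(0.1)
        set_flag('ready.model', False)


def trainer(rounds):
    reloads = 0
    while reloads < rounds:
        time.sleep(0.1)  # stand-in for a train step
        if flag_ready('ready.index'):
            set_flag('ready.index', False)
            print('trainer: reloaded index, wrote checkpoint')
            set_flag('ready.model', True)
            reloads += 1


jobs = [threading.Thread(target=indexer, args=(3,)),
        threading.Thread(target=trainer, args=(3,))]
for j in jobs:
    j.start()
for j in jobs:
    j.join()

One caveat: if train() receives the DDP-wrapped model, as in stock Megatron, hasattr(model, 'retriever') would be False and the reload branch would never fire; probing model.module instead may be needed.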
@@ -18,7 +18,7 @@
 import torch
 import torch.nn.functional as F

-from hashed_index import load_ict_checkpoint, get_ict_dataset
+from indexer import load_ict_checkpoint, get_ict_dataset
 from megatron.data.realm_index import BlockData, RandProjectionLSHIndex, FaissMIPSIndex
 from megatron import get_args
 from megatron import get_timers
@@ -41,7 +41,7 @@ def model_provider():
     ict_dataset = get_ict_dataset(use_titles=False)
     all_block_data = BlockData.load_from_file(args.block_data_path)
     # hashed_index = RandProjectionLSHIndex.load_from_file(args.block_index_path)
-    hashed_index = FaissMIPSIndex(index_type='flat_l2', embed_size=128)
+    hashed_index = FaissMIPSIndex(index_type='flat_ip', embed_size=128)
     hashed_index.add_block_embed_data(all_block_data)

     # top_k + 1 because we may need to exclude trivial candidate
@@ -102,7 +102,6 @@ def forward_step(data_iterator, model):
     reduced_loss = reduce_losses([lm_loss])

     torch.cuda.synchronize()
-    print(reduced_loss, flush=True)

     return lm_loss, {'lm_loss': reduced_loss[0]}
...
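
The flat_l2 to flat_ip switch in model_provider is more than a renaming: IndexFlatIP maximizes the raw inner product directly, while the flat_l2 branch above apparently pads vectors (hence the embed_size + 2 * self.m dimensionality) so that an L2 search can emulate inner-product search. A self-contained sanity check that exact inner-product search needs no such transform (not code from this commit):

import numpy as np
import faiss

d, n = 128, 1000
xb = np.random.rand(n, d).astype('float32')
q = np.random.rand(1, d).astype('float32')

index = faiss.IndexFlatIP(d)  # exact maximum-inner-product search
index.add(xb)
_, ids = index.search(q, 1)

# Brute force gives the same winner: MIPS directly, no L2 reduction needed.
assert ids[0, 0] == int(np.argmax(xb @ q[0]))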