Commit 612f438a authored by Mostofa Patwary

evaluation code ongoing

parent ebc95c35
@@ -371,7 +371,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
     return iteration
-def load_ict_checkpoint(model, only_query_model=False, only_block_model=False, from_realm_chkpt=False):
+def load_ict_checkpoint(model, only_query_model=False, only_context_model=False, from_realm_chkpt=False):
    """selectively load ICT models for indexing/retrieving from ICT or REALM checkpoints"""
    args = get_args()
@@ -393,14 +393,16 @@ def load_ict_checkpoint(model, only_query_model=False, only_block_model=False, f
     state_dict = torch.load(checkpoint_name, map_location='cpu')
     ict_state_dict = state_dict['model']
+    print(ict_state_dict)
+    sys.exit()
     if from_realm_chkpt and mpu.get_data_parallel_rank() == 0:
         print(" loading ICT state dict from REALM", flush=True)
         ict_state_dict = ict_state_dict['retriever']['ict_model']
     if only_query_model:
         ict_state_dict.pop('context_model')
-    if only_block_model:
-        ict_state_dict.pop('question_model')
+    if only_context_model:
+        ict_state_dict.pop('query_model')
     model.load_state_dict(ict_state_dict)
     torch.distributed.barrier()
......
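For call sites, the keyword rename in the hunk above plays out roughly as below. This is a condensed sketch (the real call site updated by this commit is in the IndexBuilder hunk further down), and the surrounding model construction is illustrative:

```python
# Before this commit, an indexing job loaded the block/context tower with:
#   model = load_ict_checkpoint(model, only_block_model=True, from_realm_chkpt=False)
# After this commit, the same intent uses the biencoder naming:
model = get_model(lambda: biencoder_model_provider(only_context_model=True))
model = load_ict_checkpoint(model, only_context_model=True, from_realm_chkpt=False)
model.eval()  # the indexer only runs forward passes
```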
@@ -9,6 +9,33 @@ from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_co
 from megatron import get_args, get_tokenizer, print_rank_0, mpu
+def get_one_epoch_dataloader(dataset, micro_batch_size=None):
+    """Specifically one epoch to be used in an indexing job."""
+    args = get_args()
+    world_size = mpu.get_data_parallel_world_size()
+    rank = mpu.get_data_parallel_rank()
+    if micro_batch_size is None:
+        micro_batch_size = args.micro_batch_size
+    global_batch_size = micro_batch_size * world_size
+    num_workers = args.num_workers
+    sampler = torch.utils.data.SequentialSampler(dataset)
+    # importantly, drop_last must be False to get all the data.
+    assert False, 'DistributedBatchSampler deprecated, change the implementation'
+    from megatron.data.samplers import DistributedBatchSampler
+    batch_sampler = DistributedBatchSampler(sampler,
+                                            batch_size=global_batch_size,
+                                            drop_last=False,
+                                            rank=rank,
+                                            world_size=world_size)
+    return torch.utils.data.DataLoader(dataset,
+                                       batch_sampler=batch_sampler,
+                                       num_workers=num_workers,
+                                       pin_memory=True)
 def get_ict_batch(data_iterator):
     # Items and their type.
     keys = ['query_tokens', 'query_mask',
......
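The new get_one_epoch_dataloader above is deliberately fenced off with an assert because DistributedBatchSampler has been deprecated. One possible replacement, sketched here as an assumption rather than as part of the commit, is PyTorch's built-in DistributedSampler (its drop_last flag requires PyTorch 1.8+, and unlike the old sampler it may repeat a few samples to keep ranks even):

```python
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def get_one_epoch_dataloader_sketch(dataset, micro_batch_size, rank, world_size,
                                    num_workers=2):
    """Illustrative stand-in for the deprecated DistributedBatchSampler path."""
    # drop_last=False keeps the tail of the dataset; DistributedSampler pads by
    # repeating samples so every rank sees the same number of batches.
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank,
                                 shuffle=False, drop_last=False)
    return DataLoader(dataset,
                      sampler=sampler,
                      batch_size=micro_batch_size,
                      num_workers=num_workers,
                      pin_memory=True)
```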
@@ -14,28 +14,29 @@ def detach(tensor):
     return tensor.detach().cpu().numpy()
-class BlockData(object):
-    """Serializable data structure for holding data for blocks -- embeddings and necessary metadata for REALM"""
-    def __init__(self, block_data_path=None, load_from_path=True, rank=None):
+class OpenRetreivalDataStore(object):
+    """Serializable data structure for holding data for blocks -- embeddings
+    and necessary metadata for Retriever"""
+    def __init__(self, embedding_path=None, load_from_path=True, rank=None):
         self.embed_data = dict()
-        self.meta_data = dict()
-        if block_data_path is None:
+        #self.meta_data = dict()
+        if embedding_path is None:
             args = get_args()
-            block_data_path = args.block_data_path
+            embedding_path = args.embedding_path
             rank = args.rank
-        self.block_data_path = block_data_path
+        self.embedding_path = embedding_path
         self.rank = rank
         if load_from_path:
             self.load_from_file()
-        block_data_name = os.path.splitext(self.block_data_path)[0]
+        block_data_name = os.path.splitext(self.embedding_path)[0]
         self.temp_dir_name = block_data_name + '_tmp'
     def state(self):
         return {
             'embed_data': self.embed_data,
-            'meta_data': self.meta_data,
+            #'meta_data': self.meta_data,
         }
     def clear(self):
@@ -50,26 +51,28 @@ class BlockData(object):
         if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
             print("\n> Unpickling BlockData", flush=True)
-        state_dict = pickle.load(open(self.block_data_path, 'rb'))
+        state_dict = pickle.load(open(self.embedding_path, 'rb'))
         if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
             print(">> Finished unpickling BlockData\n", flush=True)
         self.embed_data = state_dict['embed_data']
-        self.meta_data = state_dict['meta_data']
+        #self.meta_data = state_dict['meta_data']
-    def add_block_data(self, block_indices, block_embeds, block_metas, allow_overwrite=False):
+    #def add_block_data(self, block_indices, block_embeds, block_metas, allow_overwrite=False):
+    def add_block_data(self, row_id, block_embeds, allow_overwrite=False):
         """Add data for set of blocks
-        :param block_indices: 1D array of unique int ids for the blocks
+        :param row_id: 1D array of unique int ids for the blocks
         :param block_embeds: 2D array of embeddings of the blocks
-        :param block_metas: 2D array of metadata for the blocks.
+        #:param block_metas: 2D array of metadata for the blocks.
             In the case of REALM this will be [start_idx, end_idx, doc_idx]
         """
-        for idx, embed, meta in zip(block_indices, block_embeds, block_metas):
+        #for idx, embed, meta in zip(block_indices, block_embeds, block_metas):
+        for idx, embed in zip(row_id, block_embeds):
             if not allow_overwrite and idx in self.embed_data:
                 raise ValueError("Unexpectedly tried to overwrite block data")
             self.embed_data[idx] = np.float16(embed)
-            self.meta_data[idx] = meta
+            #self.meta_data[idx] = meta
     def save_shard(self):
         """Save the block data that was created this in this process"""
@@ -77,8 +80,8 @@ class BlockData(object):
         os.makedirs(self.temp_dir_name, exist_ok=True)
         # save the data for each shard
-        with open('{}/{}.pkl'.format(self.temp_dir_name, self.rank), 'wb') as data_file:
-            pickle.dump(self.state(), data_file)
+        with open('{}/{}.pkl'.format(self.temp_dir_name, self.rank), 'wb') as writer:
+            pickle.dump(self.state(), writer)
     def merge_shards_and_save(self):
         """Combine all the shards made using self.save_shard()"""
@@ -98,13 +101,13 @@ class BlockData(object):
                 # add the shard's data and check to make sure there is no overlap
                 self.embed_data.update(data['embed_data'])
-                self.meta_data.update(data['meta_data'])
+                #self.meta_data.update(data['meta_data'])
                 assert len(self.embed_data) == old_size + shard_size
         assert seen_own_shard
         # save the consolidated shards and remove temporary directory
-        with open(self.block_data_path, 'wb') as final_file:
+        with open(self.embedding_path, 'wb') as final_file:
             pickle.dump(self.state(), final_file)
         shutil.rmtree(self.temp_dir_name, ignore_errors=True)
......
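Putting the reworked store together, each data-parallel rank would use it roughly as follows; the path, ids, and embedding size below are placeholders, and the snippet is a sketch rather than code from this commit:

```python
import numpy as np

# rank 0 of a hypothetical single-process job; embedding_path is illustrative.
store = OpenRetreivalDataStore(embedding_path='embeddings.pkl',
                               load_from_path=False, rank=0)

# row_id: one unique int id per block; block_embeds: one embedding row per id.
row_id = np.array([0, 1, 2])
block_embeds = np.random.randn(3, 128).astype(np.float32)  # 128-dim is arbitrary
store.add_block_data(row_id, block_embeds)  # stored internally as float16

store.save_shard()             # writes embeddings_tmp/<rank>.pkl for this rank
# after all ranks have saved (e.g. behind a torch.distributed.barrier()):
store.merge_shards_and_save()  # consolidates the shards into embedding_path
```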
+import sys
 import torch
 import torch.distributed as dist
@@ -5,10 +6,11 @@ from megatron import get_args
 from megatron import mpu
 from megatron.checkpointing import load_ict_checkpoint
 from megatron.data.ict_dataset import get_ict_dataset
-from megatron.data.realm_dataset_utils import get_one_epoch_dataloader
-from megatron.data.realm_index import detach, BlockData
-from megatron.data.realm_dataset_utils import get_ict_batch
-from megatron.model.realm_model import general_ict_model_provider
+from megatron.data.biencoder_dataset_utils import get_one_epoch_dataloader
+from megatron.data.realm_index import detach, OpenRetreivalDataStore
+from megatron.data.biencoder_dataset_utils import get_ict_batch
+from megatron.model.biencoder_model import biencoder_model_provider
+#from megatron.model.realm_model import general_ict_model_provider
 from megatron.training import get_model
@@ -34,13 +36,16 @@ class IndexBuilder(object):
     def load_attributes(self):
         """Load the necessary attributes: model, dataloader and empty BlockData"""
-        model = get_model(lambda: general_ict_model_provider(only_block_model=True))
-        self.model = load_ict_checkpoint(model, only_block_model=True, from_realm_chkpt=self.using_realm_chkpt)
+        model = get_model(lambda: biencoder_model_provider(only_context_model=True))
+        self.model = load_ict_checkpoint(model, only_context_model=True, from_realm_chkpt=self.using_realm_chkpt)
+        sys.exit()
         self.model.eval()
         self.dataset = get_ict_dataset()
         self.dataloader = iter(get_one_epoch_dataloader(self.dataset, self.batch_size))
-        self.block_data = BlockData(load_from_path=False)
+        self.block_data = OpenRetreivalDataStore(load_from_path=False)
+        print("load_attributes is done", flush=True)
+        sys.exit()
     def track_and_report_progress(self, batch_size):
         """Utility function for tracking progress"""
         self.iteration += 1
......
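For orientation, once the two sys.exit() debugging stops in load_attributes are removed, the indexing loop these attributes feed looks roughly like the sketch below; the batch unpacking and embed_context call are assumptions, only the data-store and progress-tracking calls follow this diff:

```python
# Illustrative indexing loop (not the commit's code): embed every block once
# and stash the embeddings in the OpenRetreivalDataStore.
for batch in self.dataloader:
    block_tokens, block_mask, block_indices = batch              # unpacking is assumed
    embeds = self.model.embed_context(block_tokens, block_mask)  # assumed model API
    self.block_data.add_block_data(detach(block_indices), detach(embeds))
    self.track_and_report_progress(batch_size=len(block_indices))

# each rank saves its shard, then one rank consolidates (see the store sketch above)
self.block_data.save_shard()
torch.distributed.barrier()
if mpu.get_data_parallel_rank() == 0:
    self.block_data.merge_shards_and_save()
```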