Commit eaa5d877 authored by Neel Kant

Add additional assertion on Indexer to test correctness, and limit verbosity in other classes

parent c2a32e12
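The verbosity changes below all apply the same guard so that status messages are printed once rather than by every worker. A minimal sketch of the pattern (the helper name rank0_print is illustrative and not part of this commit):

    from megatron import mpu

    def rank0_print(message):
        # Print only when mpu has not been initialized (plain single-process
        # runs) or when this process is data-parallel rank 0.
        if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
            print(message, flush=True)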
@@ -7,6 +7,7 @@ import numpy as np
 import torch

 from megatron import get_args
+from megatron import mpu


 def detach(tensor):
@@ -47,9 +48,11 @@ class BlockData(object):

     def load_from_file(self):
         """Populate members from instance saved to file"""
-        print("\n> Unpickling BlockData", flush=True)
+        if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
+            print("\n> Unpickling BlockData", flush=True)
         state_dict = pickle.load(open(self.block_data_path, 'rb'))
-        print(">> Finished unpickling BlockData\n", flush=True)
+        if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
+            print(">> Finished unpickling BlockData\n", flush=True)

         self.embed_data = state_dict['embed_data']
         self.meta_data = state_dict['meta_data']
@@ -127,7 +130,8 @@ class FaissMIPSIndex(object):
         except ImportError:
             raise Exception("Error: Please install faiss to use FaissMIPSIndex")

-        print("\n> Building index", flush=True)
+        if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
+            print("\n> Building index", flush=True)
         self.block_mips_index = faiss.index_factory(self.embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)

         if self.use_gpu:
@@ -138,11 +142,13 @@ class FaissMIPSIndex(object):
             config.useFloat16 = True
             self.block_mips_index = faiss.GpuIndexFlat(res, self.block_mips_index, config)
-            print(">> Initialized index on GPU {}".format(self.block_mips_index.getDevice()), flush=True)
+            if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
+                print(">> Initialized index on GPU {}".format(self.block_mips_index.getDevice()), flush=True)
         else:
             # CPU index supports IDs so wrap with IDMap
             self.block_mips_index = faiss.IndexIDMap(self.block_mips_index)
-            print(">> Initialized index on CPU", flush=True)
+            if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
+                print(">> Initialized index on CPU", flush=True)

         # if we were constructed with a BlockData, then automatically load it when the FAISS structure is built
         if self.block_data is not None:
@@ -156,7 +162,7 @@ class FaissMIPSIndex(object):
         if self.block_data is not None:
             block_data_path = self.block_data.block_data_path
             del self.block_data
-            self.block_data = BlockData.load_from_file(block_data_path)
+            self.block_data = BlockData(block_data_path)

             self._set_block_index()
@@ -183,7 +189,8 @@ class FaissMIPSIndex(object):
         else:
             self.block_mips_index.add_with_ids(block_embeds_arr, block_indices_arr)

-        print(">>> Finished adding block data to index", flush=True)
+        if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
+            print(">>> Finished adding block data to index", flush=True)

     def search_mips_index(self, query_embeds, top_k, reconstruct=True):
         """Get the top-k blocks by the index distance metric.
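For reference, the FAISS calls used in FaissMIPSIndex behave as in this self-contained sketch (sizes and data are illustrative; search is included for completeness alongside the add_with_ids path shown above):

    import faiss
    import numpy as np

    embed_size = 128
    # Flat index with the inner-product (MIPS) metric, as built above.
    index = faiss.index_factory(embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)
    # CPU path: wrap with IDMap so blocks keep their own integer IDs.
    index = faiss.IndexIDMap(index)

    block_embeds = np.random.rand(1000, embed_size).astype('float32')
    block_ids = np.arange(1000, dtype='int64')
    index.add_with_ids(block_embeds, block_ids)

    queries = np.random.rand(4, embed_size).astype('float32')
    scores, top_ids = index.search(queries, 5)  # top-5 inner products per query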
@@ -37,7 +37,8 @@ class IndexBuilder(object):
         model = get_model(lambda: general_ict_model_provider(only_block_model=True))
         self.model = load_ict_checkpoint(model, only_block_model=True, from_realm_chkpt=self.using_realm_chkpt)
         self.model.eval()
-        self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset(), self.batch_size))
+        self.dataset = get_ict_dataset()
+        self.dataloader = iter(get_one_epoch_dataloader(self.dataset, self.batch_size))
         self.block_data = BlockData(load_from_path=False)

     def track_and_report_progress(self, batch_size):
@@ -58,7 +59,7 @@ class IndexBuilder(object):
             try:
                 # batch also has query_tokens and query_pad_data
                 _, _, block_tokens, block_pad_mask, block_sample_data = get_ict_batch(self.dataloader)
-            except StopIteration:
+            except (StopIteration, IndexError):
                 break

             unwrapped_model = self.model
@@ -85,6 +86,6 @@ class IndexBuilder(object):
         # rank 0 process builds the final copy
         if self.is_main_builder:
             self.block_data.merge_shards_and_save()
+            # make sure that every single piece of data was embedded
+            assert len(self.block_data.embed_data) == len(self.dataset)

         self.block_data.clear()
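The new assertion checks that, after rank 0 merges the per-rank shards, every sample in the dataset received exactly one embedding; since the batch loop now also breaks on IndexError, this guards against a pass that ends early but incomplete. A toy illustration of the invariant (names and data are stand-ins):

    # embed_data maps sample id -> embedding; merging shards must cover the
    # whole dataset, with no samples dropped by the data-parallel split.
    dataset = list(range(8))                        # stand-in for the ICT dataset
    embed_data = {i: [0.0] * 128 for i in dataset}  # stand-in embeddings
    assert len(embed_data) == len(dataset)          # fails if any sample was skipped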
@@ -21,6 +21,7 @@ from .data import broadcast_data
 from .grads import clip_grad_norm

+from .initialize import is_unitialized
 from .initialize import destroy_model_parallel
 from .initialize import get_data_parallel_group
 from .initialize import get_data_parallel_rank
@@ -31,6 +31,11 @@ _MPU_WORLD_SIZE = None
 _MPU_RANK = None


+def is_unitialized():
+    """Useful for code segments that may be accessed with or without mpu initialization"""
+    return _DATA_PARALLEL_GROUP is None
+
+
 def initialize_model_parallel(model_parallel_size_):
     """
     Initialize model data parallel groups.
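A sketch of how the new helper behaves around initialization (single-process setup; the backend and rendezvous settings are illustrative):

    import os
    import torch
    from megatron import mpu

    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '29500')
    torch.distributed.init_process_group('gloo', world_size=1, rank=0)

    assert mpu.is_unitialized()          # no data-parallel group created yet
    mpu.initialize_model_parallel(1)     # creates _DATA_PARALLEL_GROUP
    assert not mpu.is_unitialized()      # rank queries are now safe
    print(mpu.get_data_parallel_rank())  # -> 0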