"src/nni_manager/vscode:/vscode.git/clone" did not exist on "a587648999af697474fc48aff00b870736bfd618"
Commit eaa5d877 authored by Neel Kant
Browse files

Add additional assertion on Indexer to test correctness, and limit verbosity in other classes

parent c2a32e12
...@@ -7,6 +7,7 @@ import numpy as np ...@@ -7,6 +7,7 @@ import numpy as np
import torch import torch
from megatron import get_args from megatron import get_args
from megatron import mpu
def detach(tensor): def detach(tensor):
...@@ -47,9 +48,11 @@ class BlockData(object): ...@@ -47,9 +48,11 @@ class BlockData(object):
def load_from_file(self): def load_from_file(self):
"""Populate members from instance saved to file""" """Populate members from instance saved to file"""
print("\n> Unpickling BlockData", flush=True) if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
print("\n> Unpickling BlockData", flush=True)
state_dict = pickle.load(open(self.block_data_path, 'rb')) state_dict = pickle.load(open(self.block_data_path, 'rb'))
print(">> Finished unpickling BlockData\n", flush=True) if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
print(">> Finished unpickling BlockData\n", flush=True)
self.embed_data = state_dict['embed_data'] self.embed_data = state_dict['embed_data']
self.meta_data = state_dict['meta_data'] self.meta_data = state_dict['meta_data']
...@@ -127,7 +130,8 @@ class FaissMIPSIndex(object): ...@@ -127,7 +130,8 @@ class FaissMIPSIndex(object):
except ImportError: except ImportError:
raise Exception("Error: Please install faiss to use FaissMIPSIndex") raise Exception("Error: Please install faiss to use FaissMIPSIndex")
print("\n> Building index", flush=True) if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
print("\n> Building index", flush=True)
self.block_mips_index = faiss.index_factory(self.embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT) self.block_mips_index = faiss.index_factory(self.embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)
if self.use_gpu: if self.use_gpu:
...@@ -138,11 +142,13 @@ class FaissMIPSIndex(object): ...@@ -138,11 +142,13 @@ class FaissMIPSIndex(object):
config.useFloat16 = True config.useFloat16 = True
self.block_mips_index = faiss.GpuIndexFlat(res, self.block_mips_index, config) self.block_mips_index = faiss.GpuIndexFlat(res, self.block_mips_index, config)
print(">> Initialized index on GPU {}".format(self.block_mips_index.getDevice()), flush=True) if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
print(">> Initialized index on GPU {}".format(self.block_mips_index.getDevice()), flush=True)
else: else:
# CPU index supports IDs so wrap with IDMap # CPU index supports IDs so wrap with IDMap
self.block_mips_index = faiss.IndexIDMap(self.block_mips_index) self.block_mips_index = faiss.IndexIDMap(self.block_mips_index)
print(">> Initialized index on CPU", flush=True) if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
print(">> Initialized index on CPU", flush=True)
# if we were constructed with a BlockData, then automatically load it when the FAISS structure is built # if we were constructed with a BlockData, then automatically load it when the FAISS structure is built
if self.block_data is not None: if self.block_data is not None:
...@@ -156,7 +162,7 @@ class FaissMIPSIndex(object): ...@@ -156,7 +162,7 @@ class FaissMIPSIndex(object):
if self.block_data is not None: if self.block_data is not None:
block_data_path = self.block_data.block_data_path block_data_path = self.block_data.block_data_path
del self.block_data del self.block_data
self.block_data = BlockData.load_from_file(block_data_path) self.block_data = BlockData(block_data_path)
self._set_block_index() self._set_block_index()
...@@ -183,7 +189,8 @@ class FaissMIPSIndex(object): ...@@ -183,7 +189,8 @@ class FaissMIPSIndex(object):
else: else:
self.block_mips_index.add_with_ids(block_embeds_arr, block_indices_arr) self.block_mips_index.add_with_ids(block_embeds_arr, block_indices_arr)
print(">>> Finished adding block data to index", flush=True) if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
print(">>> Finished adding block data to index", flush=True)
def search_mips_index(self, query_embeds, top_k, reconstruct=True): def search_mips_index(self, query_embeds, top_k, reconstruct=True):
"""Get the top-k blocks by the index distance metric. """Get the top-k blocks by the index distance metric.
......
...@@ -37,7 +37,8 @@ class IndexBuilder(object): ...@@ -37,7 +37,8 @@ class IndexBuilder(object):
model = get_model(lambda: general_ict_model_provider(only_block_model=True)) model = get_model(lambda: general_ict_model_provider(only_block_model=True))
self.model = load_ict_checkpoint(model, only_block_model=True, from_realm_chkpt=self.using_realm_chkpt) self.model = load_ict_checkpoint(model, only_block_model=True, from_realm_chkpt=self.using_realm_chkpt)
self.model.eval() self.model.eval()
self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset(), self.batch_size)) self.dataset = get_ict_dataset()
self.dataloader = iter(get_one_epoch_dataloader(self.dataset, self.batch_size))
self.block_data = BlockData(load_from_path=False) self.block_data = BlockData(load_from_path=False)
def track_and_report_progress(self, batch_size): def track_and_report_progress(self, batch_size):
...@@ -58,7 +59,7 @@ class IndexBuilder(object): ...@@ -58,7 +59,7 @@ class IndexBuilder(object):
try: try:
# batch also has query_tokens and query_pad_data # batch also has query_tokens and query_pad_data
_, _, block_tokens, block_pad_mask, block_sample_data = get_ict_batch(self.dataloader) _, _, block_tokens, block_pad_mask, block_sample_data = get_ict_batch(self.dataloader)
except StopIteration: except (StopIteration, IndexError):
break break
unwrapped_model = self.model unwrapped_model = self.model
...@@ -85,6 +86,6 @@ class IndexBuilder(object): ...@@ -85,6 +86,6 @@ class IndexBuilder(object):
# rank 0 process builds the final copy # rank 0 process builds the final copy
if self.is_main_builder: if self.is_main_builder:
self.block_data.merge_shards_and_save() self.block_data.merge_shards_and_save()
# make sure that every single piece of data was embedded
assert len(self.block_data.embed_data) == len(self.dataset)
self.block_data.clear() self.block_data.clear()
...@@ -21,6 +21,7 @@ from .data import broadcast_data ...@@ -21,6 +21,7 @@ from .data import broadcast_data
from .grads import clip_grad_norm from .grads import clip_grad_norm
from .initialize import is_unitialized
from .initialize import destroy_model_parallel from .initialize import destroy_model_parallel
from .initialize import get_data_parallel_group from .initialize import get_data_parallel_group
from .initialize import get_data_parallel_rank from .initialize import get_data_parallel_rank
......
...@@ -31,6 +31,11 @@ _MPU_WORLD_SIZE = None ...@@ -31,6 +31,11 @@ _MPU_WORLD_SIZE = None
_MPU_RANK = None _MPU_RANK = None
def is_unitialized():
    """Report whether the data parallel group has not been set yet.

    Useful for code segments that may be reached with or without mpu
    initialization, so they can check before calling other mpu helpers.

    Returns:
        bool: True when the ``_DATA_PARALLEL_GROUP`` global is ``None``,
        False otherwise.
    """
    data_parallel_group = _DATA_PARALLEL_GROUP
    return data_parallel_group is None
def initialize_model_parallel(model_parallel_size_): def initialize_model_parallel(model_parallel_size_):
""" """
Initialize model data parallel groups. Initialize model data parallel groups.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment