Commit d9d4ce70 authored by Neel Kant's avatar Neel Kant

adding realm_index.py

parent b0a3c636
from collections import defaultdict
import itertools
import os
import pickle
import shutil

import faiss
import numpy as np
import torch

from megatron import get_args, mpu


def detach(tensor):
    """Detach a torch tensor and convert it to a numpy array."""
    return tensor.detach().cpu().numpy()


class BlockData(object):
    """Serializable data structure for holding data for blocks --
    embeddings and necessary metadata for REALM"""

    def __init__(self, block_data_path=None, rank=None):
        self.embed_data = dict()
        self.meta_data = dict()
        if block_data_path is None:
            args = get_args()
            block_data_path = args.block_data_path
            rank = args.rank
        self.block_data_path = block_data_path
        self.rank = rank

        block_data_name = os.path.splitext(self.block_data_path)[0]
        self.temp_dir_name = block_data_name + '_tmp'

    def state(self):
        return {
            'embed_data': self.embed_data,
            'meta_data': self.meta_data,
        }

    def clear(self):
        """Clear the embedding data structures to save memory.

        The metadata is still needed downstream and is much smaller,
        so it is not worth clearing.
        """
        self.embed_data = dict()

    @classmethod
    def load_from_file(cls, fname):
        print("\n> Unpickling BlockData", flush=True)
        with open(fname, 'rb') as f:
            state_dict = pickle.load(f)
        print(">> Finished unpickling BlockData\n", flush=True)

        new_index = cls()
        new_index.embed_data = state_dict['embed_data']
        new_index.meta_data = state_dict['meta_data']
        return new_index
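
    # A minimal sketch of the round trip this class supports (the path,
    # embedding size, and metadata values below are hypothetical):
    #
    #   block_data = BlockData(block_data_path='block_data.pkl', rank=0)
    #   block_data.add_block_data([7], [np.random.randn(128)], [('doc', 0)])
    #   block_data.save_shard()
    #   block_data.merge_shards_and_save()
    #   restored = BlockData.load_from_file('block_data.pkl')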

    def add_block_data(self, block_indices, block_embeds, block_metas, allow_overwrite=False):
        for idx, embed, meta in zip(block_indices, block_embeds, block_metas):
            if not allow_overwrite and idx in self.embed_data:
                raise ValueError("Unexpectedly tried to overwrite block data")

            # store embeddings in half precision to halve the memory footprint
            self.embed_data[idx] = np.float16(embed)
            self.meta_data[idx] = meta

    def save_shard(self):
        """Save this rank's data into its own file in a temporary directory."""
        os.makedirs(self.temp_dir_name, exist_ok=True)

        # save the data for each shard
        with open('{}/{}.pkl'.format(self.temp_dir_name, self.rank), 'wb') as data_file:
            pickle.dump(self.state(), data_file)

    def merge_shards_and_save(self):
        """Combine all the shards made using self.save_shard()"""
        shard_names = os.listdir(self.temp_dir_name)
        seen_own_shard = False

        for fname in shard_names:
            shard_rank = int(os.path.splitext(fname)[0])
            if shard_rank == self.rank:
                seen_own_shard = True
                continue

            with open('{}/{}'.format(self.temp_dir_name, fname), 'rb') as f:
                data = pickle.load(f)
                old_size = len(self.embed_data)
                shard_size = len(data['embed_data'])

                # add the shard's data and check to make sure there is no overlap
                self.embed_data.update(data['embed_data'])
                self.meta_data.update(data['meta_data'])
                assert len(self.embed_data) == old_size + shard_size

        assert seen_own_shard

        # save the consolidated shards and remove temporary directory
        with open(self.block_data_path, 'wb') as final_file:
            pickle.dump(self.state(), final_file)
        shutil.rmtree(self.temp_dir_name, ignore_errors=True)

        print("Finished merging {} shards for a total of {} embeds".format(
            len(shard_names), len(self.embed_data)), flush=True)
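
# A sketch of how the shard/merge protocol above is typically driven from a
# distributed job (torch.distributed.barrier() here is an assumed choice of
# synchronization, not something this module mandates):
#
#   block_data.save_shard()                 # every rank writes <tmp>/<rank>.pkl
#   torch.distributed.barrier()             # wait until all shards are on disk
#   if args.rank == 0:
#       block_data.merge_shards_and_save()  # rank 0 folds every shard into one file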


class FaissMIPSIndex(object):
    """Wrapper object for a BlockData which does maximum inner product
    similarity search via FAISS under the hood"""

    def __init__(self, index_type, embed_size, use_gpu=False):
        self.index_type = index_type
        self.embed_size = embed_size
        self.use_gpu = use_gpu

        self.id_map = dict()
        self.block_mips_index = None
        self._set_block_index()

    def _set_block_index(self):
        INDEX_TYPES = ['flat_ip']
        if self.index_type not in INDEX_TYPES:
            raise ValueError("Invalid index type specified")

        print("\n> Building index", flush=True)
        self.block_mips_index = faiss.index_factory(self.embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)

        if self.use_gpu:
            # create resources and config for GpuIndex
            res = faiss.StandardGpuResources()
            config = faiss.GpuIndexFlatConfig()
            config.device = torch.cuda.current_device()
            config.useFloat16 = True

            self.block_mips_index = faiss.GpuIndexFlat(res, self.block_mips_index, config)
            print(">>> Finished building index on GPU {}\n".format(self.block_mips_index.getDevice()), flush=True)
        else:
            # CPU index supports custom IDs so wrap with IDMap
            self.block_mips_index = faiss.IndexIDMap(self.block_mips_index)
            print(">> Finished building index\n", flush=True)

    def reset_index(self):
        """Delete existing index and create anew"""
        del self.block_mips_index
        self._set_block_index()

    def add_block_embed_data(self, all_block_data):
        """Add the embedding of each block to the underlying FAISS index"""
        block_indices, block_embeds = zip(*all_block_data.embed_data.items())

        # the GPU flat index cannot store custom ids, so remember which block
        # index lives at each sequential position instead
        if self.use_gpu:
            for i, idx in enumerate(block_indices):
                self.id_map[i] = idx

        # block_indices and block_embeds now hold the only references needed,
        # so the BlockData's embedding dict can be freed to save memory
        all_block_data.clear()

        if self.use_gpu:
            self.block_mips_index.add(np.float32(np.array(block_embeds)))
        else:
            self.block_mips_index.add_with_ids(np.float32(np.array(block_embeds)), np.array(block_indices))
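
    # Design note: the GPU flat index does not support add_with_ids() (hence
    # the IDMap wrapper on the CPU path above), so on GPU the blocks sit at
    # sequential positions and self.id_map translates a returned position back
    # into its block index, e.g.:
    #
    #   pos = 0                       # position returned by the GPU index
    #   block_idx = self.id_map[pos]  # the caller-facing block index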

    def search_mips_index(self, query_embeds, top_k, reconstruct=True):
        """Get the top-k blocks by the index distance metric.

        :param reconstruct: if True: return a [num_queries x k x embed_dim] array of blocks
                            if False: return [num_queries x k] array of distances, and another for indices
        """
        query_embeds = np.float32(detach(query_embeds))

        with torch.no_grad():
            if reconstruct:
                # search_and_reconstruct returns (distances, indices, embeddings)
                _, _, top_k_block_embeds = self.block_mips_index.search_and_reconstruct(query_embeds, top_k)
                return top_k_block_embeds
            else:
                distances, block_indices = self.block_mips_index.search(query_embeds, top_k)
                if self.use_gpu:
                    # translate sequential GPU index positions back into block indices
                    fresh_indices = np.zeros(block_indices.shape, dtype=np.int64)
                    for i, j in itertools.product(range(block_indices.shape[0]),
                                                  range(block_indices.shape[1])):
                        fresh_indices[i, j] = self.id_map[block_indices[i, j]]
                    block_indices = fresh_indices
                return distances, block_indices
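

# A minimal end-to-end sketch of how the two classes compose (the sizes,
# paths, and random data below are illustrative assumptions only):
#
#   block_data = BlockData(block_data_path='block_data.pkl', rank=0)
#   block_data.add_block_data(range(16), np.random.randn(16, 128), [None] * 16)
#
#   mips_index = FaissMIPSIndex(index_type='flat_ip', embed_size=128, use_gpu=False)
#   mips_index.add_block_embed_data(block_data)
#
#   queries = torch.randn(4, 128)
#   dists, ids = mips_index.search_mips_index(queries, top_k=5, reconstruct=False)
#   # dists, ids: [4, 5] arrays of inner products and matching block indices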