Commit a2e64ad5 authored by Neel Kant

Move REALM to use FAISS

parent 29825734
@@ -3,6 +3,7 @@ import os
 import pickle
 import shutil

+import faiss
 import numpy as np
 import torch
@@ -121,7 +122,7 @@ class FaissMIPSIndex(object):
         if self.index_type == 'flat_l2':
             block_embeds = self.alsh_block_preprocess_fn(block_embeds)
-        self.block_mips_index.add_with_ids(block_embeds, block_indices)
+        self.block_mips_index.add_with_ids(np.array(block_embeds), np.array(block_indices))

     def search_mips_index(self, query_embeds, top_k, reconstruct=True):
         """Get the top-k blocks by the index distance metric.
@@ -216,7 +217,7 @@ class RandProjectionLSHIndex(object):
     def hash_embeds(self, embeds, write_block_data=None):
         """Hash a tensor of embeddings using a random projection matrix"""
-        embed_scores_pos = torch.matmul(embeds, torch.cuda.HalfTensor(self.hash_matrix))
+        embed_scores_pos = torch.matmul(embeds, torch.cuda.FloatTensor(self.hash_matrix).type(embeds.dtype))
         embed_scores = torch.cat((embed_scores_pos, -embed_scores_pos), axis=1)
         embed_hashes = detach(torch.argmax(embed_scores, axis=1))
...
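A note on the 'flat_l2' path above: a plain L2 index can still return maximum-inner-product neighbors if the block embeddings are augmented first, which is presumably the role of alsh_block_preprocess_fn. The sketch below shows one standard MIPS-to-L2 reduction; the actual preprocessing in FaissMIPSIndex may use a different variant.

```python
# Illustrative only; the real alsh_block_preprocess_fn may differ from this transform.
import numpy as np

def augment_blocks(block_embeds):
    """Append sqrt(M^2 - ||x||^2) to each block so L2 distance ranks by inner product."""
    norms = np.linalg.norm(block_embeds, axis=1)
    max_norm = norms.max()
    extra = np.sqrt(np.maximum(max_norm ** 2 - norms ** 2, 0.0))
    return np.hstack([block_embeds, extra[:, None]])

def augment_queries(query_embeds):
    """Queries get a zero in the extra coordinate."""
    return np.hstack([query_embeds, np.zeros((query_embeds.shape[0], 1))])

# With x' = [x, sqrt(M^2 - ||x||^2)] and q' = [q, 0]:
#   ||q' - x'||^2 = ||q||^2 + M^2 - 2 q.x,
# so for a fixed query the L2-nearest block is exactly the max-inner-product block.
rng = np.random.default_rng(0)
blocks, queries = rng.standard_normal((100, 8)), rng.standard_normal((2, 8))
l2_best = np.argmin(((augment_queries(queries)[:, None] - augment_blocks(blocks)[None]) ** 2).sum(-1), axis=1)
ip_best = np.argmax(queries @ blocks.T, axis=1)
assert (l2_best == ip_best).all()
```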
@@ -22,6 +22,7 @@ import torch
 import torch.nn.functional as F

 from megatron import get_args
+from megatron.data.realm_index import detach
 from megatron.model.language_model import parallel_lm_logits
 from megatron.model.language_model import get_language_model
 from megatron.model.transformer import LayerNorm
@@ -86,7 +87,7 @@ class BertLMHead(MegatronModule):
         super(BertLMHead, self).__init__()

         args = get_args()
         self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
         self.bias.model_parallel = True
         self.bias.partition_dim = 0
@@ -247,11 +248,11 @@ class REALMBertModel(MegatronModule):
         top5_block_attention_mask = torch.cuda.LongTensor(top5_block_attention_mask).reshape(-1, seq_length)

         # [batch_size x 5 x embed_size]
-        fresh_block_logits = self.retriever.ict_model(None, None, top5_block_tokens, top5_block_attention_mask, only_block=True).reshape(batch_size, 5, -1)
-        # fresh_block_logits.register_hook(lambda x: print("fresh block: ", x.shape, flush=True))
+        true_model = self.retriever.ict_model.module.module
+        fresh_block_logits = true_model.embed_block(top5_block_tokens, top5_block_attention_mask).reshape(batch_size, 5, -1)

         # [batch_size x embed_size x 1]
-        query_logits = self.retriever.ict_model(tokens, attention_mask, None, None, only_query=True).unsqueeze(2)
+        query_logits = true_model.embed_query(tokens, attention_mask).unsqueeze(2)

         # [batch_size x 5]
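Both new call sites reach the underlying ICT model through ict_model.module.module. Below is a minimal, self-contained illustration of that unwrapping pattern; the wrapper classes are stand-ins, assuming the usual Megatron arrangement of a DDP wrapper around an FP16 wrapper around the model.

```python
# Stand-in classes for illustration; the real wrappers are torch DDP / Megatron FP16_Module.
import torch

class TinyICTModel(torch.nn.Module):
    """Plays the role of ICTBertModel with separate query/block encoders."""
    def __init__(self, hidden=4):
        super().__init__()
        self.query_proj = torch.nn.Linear(hidden, hidden)
        self.block_proj = torch.nn.Linear(hidden, hidden)

    def embed_query(self, tokens, pad_mask):
        return self.query_proj(tokens)

    def embed_block(self, tokens, pad_mask):
        return self.block_proj(tokens)

class Wrapper(torch.nn.Module):
    """Plays the role of a DDP or FP16 wrapper: the wrapped model lives at .module."""
    def __init__(self, module):
        super().__init__()
        self.module = module

# Two nested wrappers, i.e. roughly DDP(FP16_Module(model)):
ict_model = Wrapper(Wrapper(TinyICTModel()))
true_model = ict_model.module.module            # unwrap to reach embed_query / embed_block
print(true_model.embed_query(torch.randn(2, 4), None).shape)   # torch.Size([2, 4])
```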
@@ -310,36 +311,21 @@ class REALMRetriever(MegatronModule):
     def retrieve_evidence_blocks(self, query_tokens, query_pad_mask):
         """Embed blocks to be used in a forward pass"""
-        query_embeds = self.ict_model(query_tokens, query_pad_mask, None, None, only_query=True)
-        query_hashes = self.hashed_index.hash_embeds(query_embeds)
-        block_buckets = [self.hashed_index.get_block_bucket(hash) for hash in query_hashes]
-        for j, bucket in enumerate(block_buckets):
-            if len(bucket) < 5:
-                for i in range(len(block_buckets)):
-                    if len(block_buckets[i]) > 5:
-                        block_buckets[j] = block_buckets[i].copy()
-
-        # [batch_size x max_bucket_population x embed_size]
-        block_embeds = [torch.cuda.FloatTensor(np.array([self.block_data.embed_data[idx]
-                                                         for idx in bucket])) for bucket in block_buckets]
+        with torch.no_grad():
+            true_model = self.ict_model.module.module
+            query_embeds = detach(true_model.embed_query(query_tokens, query_pad_mask))
+        _, block_indices = self.hashed_index.search_mips_index(query_embeds, top_k=self.top_k, reconstruct=False)

         all_top5_tokens, all_top5_pad_masks = [], []
-        for query_embed, embed_tensor, bucket in zip(query_embeds, block_embeds, block_buckets):
-            retrieval_scores = query_embed.matmul(torch.transpose(embed_tensor.reshape(-1, query_embed.size()[0]), 0, 1))
-            print(retrieval_scores.shape, flush=True)
-            top5_vals, top5_indices = torch.topk(retrieval_scores, k=5, sorted=True)
-            top5_start_end_doc = [bucket[idx][:3] for idx in top5_indices.squeeze()]
-
-            # top_k tuples of (block_tokens, block_pad_mask)
-            top5_block_data = [self.ict_dataset.get_block(*indices) for indices in top5_start_end_doc]
+        for indices in block_indices:
+            # [k x meta_dim]
+            top5_metas = np.array([self.block_data.meta_data[idx] for idx in indices])
+            top5_block_data = [self.ict_dataset.get_block(*block_meta) for block_meta in top5_metas]
             top5_tokens, top5_pad_masks = zip(*top5_block_data)
             all_top5_tokens.append(np.array(top5_tokens))
             all_top5_pad_masks.append(np.array(top5_pad_masks))

-        # [batch_size x 5 x seq_length]
+        # [batch_size x k x seq_length]
         return np.array(all_top5_tokens), np.array(all_top5_pad_masks)
...
@@ -19,7 +19,7 @@ import torch
 import torch.nn.functional as F

 from hashed_index import load_ict_checkpoint, get_ict_dataset
-from megatron.data.realm_index import BlockData, RandProjectionLSHIndex
+from megatron.data.realm_index import BlockData, RandProjectionLSHIndex, FaissMIPSIndex
 from megatron import get_args
 from megatron import get_timers
 from megatron import mpu
@@ -40,7 +40,9 @@ def model_provider():
     ict_model = load_ict_checkpoint()
     ict_dataset = get_ict_dataset()
     all_block_data = BlockData.load_from_file(args.block_data_path)
-    hashed_index = RandProjectionLSHIndex.load_from_file(args.block_index_path)
+    # hashed_index = RandProjectionLSHIndex.load_from_file(args.block_index_path)
+    hashed_index = FaissMIPSIndex(index_type='flat_l2', embed_size=128)
+    hashed_index.add_block_embed_data(all_block_data)
     retriever = REALMRetriever(ict_model, ict_dataset, all_block_data, hashed_index)

     # TODO: REALMBertModel should accept a path to a pretrained bert-base
...
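model_provider now builds the index in memory from the saved block embeddings instead of loading a hashed index from disk. add_block_embed_data is not shown in this diff; below is a hedged guess at what it might do, assuming block_data.embed_data behaves like a mapping from block index to embedding (that attribute does appear earlier in the diff; everything else here is illustrative and continues the toy faiss index from the earlier sketch).

```python
# Hedged sketch only; not the project's actual add_block_embed_data.
import numpy as np
import faiss

def add_block_embed_data_sketch(mips_index, block_data):
    """Bulk-load every stored block embedding into a faiss index, keyed by block id."""
    block_indices = np.array(list(block_data.embed_data.keys()), dtype=np.int64)
    block_embeds = np.stack([np.asarray(block_data.embed_data[i], dtype=np.float32)
                             for i in block_indices])
    mips_index.add_with_ids(block_embeds, block_indices)

class FakeBlockData:
    """Toy stand-in for BlockData with only the embed_data mapping."""
    def __init__(self, num_blocks, embed_size):
        rng = np.random.default_rng(0)
        self.embed_data = {i: rng.standard_normal(embed_size).astype(np.float32)
                           for i in range(num_blocks)}

index = faiss.IndexIDMap(faiss.IndexFlatL2(128))
add_block_embed_data_sketch(index, FakeBlockData(num_blocks=100, embed_size=128))
print(index.ntotal)   # 100
```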