Commit 8573ab35 authored by Neel Kant

Use Faiss GPU index and report retrieval utility

parent e59496bf
@@ -388,6 +388,7 @@ def _add_data_args(parser):
                        help='Mask loss for the end of document tokens.')
     group.add_argument('--query-in-block-prob', type=float, default=0.1,
                        help='Probability of keeping query in block for ICT dataset')
+    group.add_argument('--faiss-use-gpu', action='store_true')
     return parser
...
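As a standalone sketch, the new flag is a plain boolean switch that defaults to False, so existing launch commands are unaffected unless they opt in (the group title and help text below are illustrative, not part of the commit):

import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='data')
# store_true flags default to False, so the CPU index remains the default.
group.add_argument('--faiss-use-gpu', action='store_true',
                   help='illustrative help text: build the Faiss index on GPU')

assert parser.parse_args([]).faiss_use_gpu is False
assert parser.parse_args(['--faiss-use-gpu']).faiss_use_gpu is True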
@@ -33,9 +33,9 @@ class BlockData(object):
     @classmethod
     def load_from_file(cls, fname):
-        print(" > Unpickling block data")
+        print("\n> Unpickling block data", flush=True)
         state_dict = pickle.load(open(fname, 'rb'))
-        print(" > Finished unpickling")
+        print(">> Finished unpickling block data\n", flush=True)

         new_index = cls()
         new_index.embed_data = state_dict['embed_data']
@@ -69,7 +69,7 @@ class BlockData(object):
             shard_size = len(data['embed_data'])
             self.embed_data.update(data['embed_data'])
             self.meta_data.update(data['meta_data'])
-            assert (len(self.embed_data) == old_size + shard_size) or (str(ignore_shard) in fname)
+            # assert (len(self.embed_data) == old_size + shard_size) or (str(ignore_shard) in fname)

         args = get_args()
         with open(args.block_data_path, 'wb') as final_file:
@@ -82,6 +82,7 @@ class FaissMIPSIndex(object):
         self.index_type = index_type
         self.embed_size = embed_size
         self.use_gpu = use_gpu
+        self.id_map = dict()

         # alsh
         self.m = 5
@@ -95,12 +96,20 @@ class FaissMIPSIndex(object):
         if self.index_type not in INDEX_TYPES:
             raise ValueError("Invalid index type specified")
-        index = faiss.index_factory(self.embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)
-        self.block_mips_index = faiss.IndexIDMap(index)
+        print("\n> Building index", flush=True)
+        self.block_mips_index = faiss.index_factory(self.embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)
+        if not self.use_gpu:
+            self.block_mips_index = faiss.IndexIDMap(self.block_mips_index)
+        print(">> Finished building index", flush=True)

         if self.use_gpu:
             res = faiss.StandardGpuResources()
-            device = mpu.get_data_parallel_rank()
-            self.block_mips_index = faiss.index_cpu_to_gpu(res, device, self.block_mips_index)
+            # self.block_mips_index = faiss.index_cpu_to_gpu(res, device, self.block_mips_index)
+            config = faiss.GpuIndexFlatConfig()
+            config.device = torch.cuda.current_device()
+            config.useFloat16 = True
+            self.block_mips_index = faiss.GpuIndexFlat(res, self.block_mips_index, config)
+            print(">>> Loaded Faiss index on GPU {}\n".format(self.block_mips_index.getDevice()), flush=True)

     def reset_index(self):
         self._set_block_index()
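For readers unfamiliar with the Faiss GPU API, here is a minimal standalone sketch of the two construction paths above, assuming a faiss build with GPU support; the 128-dimension size matches the FaissMIPSIndex(embed_size=128) call later in this commit:

import numpy as np
import faiss  # assumes the GPU-enabled faiss package

embed_size = 128

# CPU path: wrap the flat inner-product index in IndexIDMap so that
# add_with_ids() can attach caller-chosen block ids to each vector.
cpu_index = faiss.index_factory(embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)
cpu_index = faiss.IndexIDMap(cpu_index)

# GPU path: the index lives directly on one device; useFloat16 stores
# vectors in half precision to halve GPU memory at some precision cost.
res = faiss.StandardGpuResources()
config = faiss.GpuIndexFlatConfig()
config.device = 0  # the commit uses torch.cuda.current_device()
config.useFloat16 = True
gpu_index = faiss.GpuIndexFlatIP(res, embed_size, config)

gpu_index.add(np.random.rand(10, embed_size).astype(np.float32))
print(gpu_index.ntotal)  # 10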
@@ -108,12 +117,16 @@ class FaissMIPSIndex(object):
     def add_block_embed_data(self, all_block_data, clear_block_data=False):
         """Add the embedding of each block to the underlying FAISS index"""
         block_indices, block_embeds = zip(*all_block_data.embed_data.items())
+        if self.use_gpu:
+            for i, idx in enumerate(block_indices):
+                self.id_map[i] = idx

         if clear_block_data:
             all_block_data.clear()

-        if self.index_type == 'flat_l2':
-            block_embeds = self.alsh_block_preprocess_fn(block_embeds)
-        self.block_mips_index.add_with_ids(np.float32(np.array(block_embeds)), np.array(block_indices))
+        if self.use_gpu:
+            self.block_mips_index.add(np.float32(np.array(block_embeds)))
+        else:
+            self.block_mips_index.add_with_ids(np.float32(np.array(block_embeds)), np.array(block_indices))
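The id_map bookkeeping exists because the GPU flat index lacks add_with_ids(): vectors are stored at sequential row positions 0..n-1 in insertion order, so the original (possibly non-contiguous) block ids must be recorded separately. A toy illustration with hypothetical data:

import numpy as np

# Hypothetical stand-in for all_block_data.embed_data: block id -> embedding,
# with deliberately non-contiguous ids.
embed_data = {17: np.ones(4), 42: np.zeros(4), 7: np.full(4, 0.5)}

block_indices, block_embeds = zip(*embed_data.items())

# Row position in the Faiss index -> original block id.
id_map = {i: idx for i, idx in enumerate(block_indices)}

print(id_map)  # {0: 17, 1: 42, 2: 7}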
     def search_mips_index(self, query_embeds, top_k, reconstruct=True):
         """Get the top-k blocks by the index distance metric.
@@ -123,14 +136,22 @@ class FaissMIPSIndex(object):
         """
         if self.index_type == 'flat_l2':
             query_embeds = self.alsh_query_preprocess_fn(query_embeds)
-        query_embeds = np.float32(query_embeds)
+        query_embeds = np.float32(detach(query_embeds))
+        # query_embeds = query_embeds.float()

-        if reconstruct:
-            top_k_block_embeds = self.block_mips_index.search_and_reconstruct(query_embeds, top_k)
-            return top_k_block_embeds
-        else:
-            distances, block_indices = self.block_mips_index.search(query_embeds, top_k)
-            return distances, block_indices
+        with torch.no_grad():
+            if reconstruct:
+                top_k_block_embeds = self.block_mips_index.search_and_reconstruct(query_embeds, top_k)
+                return top_k_block_embeds
+            else:
+                distances, block_indices = self.block_mips_index.search(query_embeds, top_k)
+                if self.use_gpu:
+                    fresh_indices = np.zeros(block_indices.shape)
+                    for i in range(block_indices.shape[0]):
+                        for j in range(block_indices.shape[1]):
+                            fresh_indices[i, j] = self.id_map[block_indices[i, j]]
+                    block_indices = fresh_indices
+                return distances, block_indices

     # functions below are for ALSH, which currently isn't being used
...
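Two side notes on the GPU remapping loop, offered as a hedged sketch rather than a claim about the commit: np.zeros() defaults to float64, so the remapped indices silently become floats unless an integer dtype is given, and the per-element Python loop can be replaced by a numpy lookup table:

import numpy as np

id_map = {0: 17, 1: 42, 2: 7}               # toy mapping as built above
block_indices = np.array([[0, 2], [1, 0]])  # row positions returned by search()

# Dense position -> block-id table; integer dtype avoids the float64 pitfall.
lookup = np.empty(len(id_map), dtype=np.int64)
for pos, block_id in id_map.items():
    lookup[pos] = block_id

fresh_indices = lookup[block_indices]  # vectorized remap
print(fresh_indices)                   # [[17  7] [42 17]]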
@@ -34,7 +34,12 @@ def reduce_losses(losses):
     reduced_losses = torch.cat(
         [loss.clone().detach().view(1) for loss in losses])
     torch.distributed.all_reduce(reduced_losses, group=get_data_parallel_group())
-    reduced_losses = reduced_losses / torch.distributed.get_world_size()
+    args = get_args()
+    if args.max_training_rank is not None:
+        num_trainers = args.max_training_rank
+    else:
+        num_trainers = torch.distributed.get_world_size()
+    reduced_losses = reduced_losses / num_trainers
     return reduced_losses
...
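The divisor change appears to matter because, with the async indexer, only the first max_training_rank ranks contribute losses to the all-reduce; dividing the sum by the full world size would then bias the reported mean low. A toy calculation (numbers are illustrative):

# 8 processes total, but only 6 are trainers; the other 2 build the index.
world_size = 8
max_training_rank = 6
per_rank_loss = 2.0

summed = per_rank_loss * max_training_rank  # value after all_reduce over trainers
print(summed / world_size)         # 1.5 -- biased low
print(summed / max_training_rank)  # 2.0 -- true mean over trainers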
@@ -26,7 +26,8 @@ from megatron import print_rank_0
 from megatron.data.dataset_utils import build_train_valid_test_datasets
 from megatron.model import REALMBertModel, REALMRetriever
 from megatron.training import pretrain
-from megatron.utils import reduce_losses
+from megatron.utils import reduce_losses, report_memory
+from megatron import mpu

 from indexer import initialize_and_run_async_megatron

 num_batches = 0
@@ -37,11 +38,14 @@ def model_provider():
     args = get_args()

     print_rank_0('building REALM models ...')
-    ict_model = load_ict_checkpoint()
+    try:
+        ict_model = load_ict_checkpoint(from_realm_chkpt=True)
+    except:
+        ict_model = load_ict_checkpoint(from_realm_chkpt=False)
     ict_dataset = get_ict_dataset(use_titles=False)
     all_block_data = BlockData.load_from_file(args.block_data_path)
     # hashed_index = RandProjectionLSHIndex.load_from_file(args.block_index_path)
-    hashed_index = FaissMIPSIndex(index_type='flat_ip', embed_size=128)
+    hashed_index = FaissMIPSIndex(index_type='flat_ip', embed_size=128, use_gpu=args.faiss_use_gpu)
     hashed_index.add_block_embed_data(all_block_data)

     # top_k + 1 because we may need to exclude trivial candidate
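The try/except falls back to a fresh ICT checkpoint when no REALM checkpoint can be loaded. A self-contained sketch of the same pattern (the loader below is a stub, and Exception is caught instead of the commit's bare except, which would also swallow KeyboardInterrupt):

def load_ict_checkpoint(from_realm_chkpt=False):
    # Stub standing in for the real loader: pretend no REALM checkpoint exists.
    if from_realm_chkpt:
        raise FileNotFoundError("no REALM checkpoint found")
    return "ict-only weights"

try:
    ict_model = load_ict_checkpoint(from_realm_chkpt=True)
except Exception:
    ict_model = load_ict_checkpoint(from_realm_chkpt=False)

print(ict_model)  # ict-only weights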
@@ -61,6 +65,9 @@ def get_batch(data_iterator):
         data = None
     else:
         data = next(data_iterator)

     data_b = mpu.broadcast_data(keys, data, datatype)

     # Unpack.
@@ -90,9 +97,11 @@ def forward_step(data_iterator, model):
     # Forward model.
     lm_logits, block_probs = model(tokens, pad_mask, query_block_indices)
     with torch.no_grad():
-        retrieval_utility = get_retrieval_utility(lm_logits, block_probs, labels, loss_mask)
+        max_retrieval_utility, top_retrieval_utility, avg_retrieval_utility = mpu.checkpoint(
+            get_retrieval_utility, lm_logits, block_probs, labels, loss_mask)

     # P(y|x) = sum_z(P(y|z, x) * P(z|x))
+    null_block_probs = torch.mean(block_probs[:, block_probs.shape[1] - 1])
     block_probs = block_probs.unsqueeze(2).unsqueeze(3).expand_as(lm_logits)
     lm_logits = torch.sum(lm_logits * block_probs, dim=1)[:, :labels.shape[1]]
@@ -101,9 +110,13 @@ def forward_step(data_iterator, model):
     lm_loss = torch.sum(
         lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
-    reduced_loss = reduce_losses([lm_loss, retrieval_utility])
+    reduced_loss = reduce_losses([lm_loss, max_retrieval_utility, top_retrieval_utility, avg_retrieval_utility, null_block_probs])
     # torch.cuda.synchronize()
-    return lm_loss, {'lm_loss': reduced_loss[0], 'retrieval_utility': reduced_loss[1]}
+    return lm_loss, {'lm_loss': reduced_loss[0],
+                     'max_ru': reduced_loss[1],
+                     'top_ru': reduced_loss[2],
+                     'avg_ru': reduced_loss[3],
+                     'null_prob': reduced_loss[4]}
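The quantity being trained against follows the marginalization P(y|x) = sum_z P(y|z, x) * P(z|x) over the retrieved blocks plus the null block. A shape-only sketch with toy sizes (the [batch, num_blocks, seq_len, vocab] layout is inferred from the expand_as call above, not stated in the commit):

import torch

batch, num_blocks, seq_len, vocab = 2, 4, 8, 16
lm_logits = torch.randn(batch, num_blocks, seq_len, vocab)
block_probs = torch.softmax(torch.randn(batch, num_blocks), dim=-1)

# Mean probability assigned to the trailing "null" (no-retrieval) block.
null_block_probs = torch.mean(block_probs[:, block_probs.shape[1] - 1])

# Broadcast P(z|x) over sequence and vocab, then sum out the block dim.
expanded = block_probs.unsqueeze(2).unsqueeze(3).expand_as(lm_logits)
marginal = torch.sum(lm_logits * expanded, dim=1)  # [batch, seq_len, vocab]

print(marginal.shape, null_block_probs.item())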
 def get_retrieval_utility(lm_logits, block_probs, labels, loss_mask):
@@ -129,9 +142,10 @@ def get_retrieval_utility(lm_logits, block_probs, labels, loss_mask):
             retrieved_block_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
         retrieved_block_losses.append(retrieved_block_loss)
     avg_retrieved_block_loss = torch.sum(torch.cuda.FloatTensor(retrieved_block_losses)) / (lm_logits.shape[1] - 1)

-    retrieval_utility = null_block_loss - avg_retrieved_block_loss
-    return retrieval_utility
+    max_retrieval_utility = null_block_loss - min(retrieved_block_losses)
+    top_retrieval_utility = null_block_loss - retrieved_block_losses[0]
+    avg_retrieval_utility = null_block_loss - avg_retrieved_block_loss
+    return max_retrieval_utility, top_retrieval_utility, avg_retrieval_utility

 def qa_forward_step(data_iterator, model):
...
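Retrieval utility measures how much a retrieved block lowers the language model's loss relative to conditioning on the null (no-retrieval) block, so positive values mean retrieval helped. A toy illustration of the three summaries now returned (values are made up; retrieved_block_losses[0] is taken to be the top-ranked block):

import torch

null_block_loss = torch.tensor(3.0)
retrieved_block_losses = [torch.tensor(2.4), torch.tensor(2.9), torch.tensor(3.2)]

avg_loss = sum(retrieved_block_losses) / len(retrieved_block_losses)

max_ru = null_block_loss - min(retrieved_block_losses)  # best block:  0.6
top_ru = null_block_loss - retrieved_block_losses[0]    # rank-1 block: 0.6
avg_ru = null_block_loss - avg_loss                     # mean: ~0.17

print(max_ru.item(), top_ru.item(), avg_ru.item())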