Commit 2f7d666c authored by Neel Kant

Add retrieval utility and autoresume for indexer

parent 9b9b8e01
import os
import sys
import time
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron import get_args
from megatron import get_args, get_adlr_autoresume, print_rank_0
from megatron import mpu
from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
from megatron.data.bert_dataset import get_indexed_dataset_
from megatron.data.realm_dataset import ICTDataset
from megatron.data.realm_index import detach, BlockData, RandProjectionLSHIndex
from megatron.data.realm_index import detach, BlockData, FaissMIPSIndex
from megatron.data.samplers import DistributedBatchSampler
from megatron.initialize import initialize_megatron
from megatron.model import REALMRetriever
from megatron.training import get_model
from megatron.utils import check_adlr_autoresume_termination
from pretrain_bert_ict import get_batch, model_provider
from indexer_utils import set_index_com_file_ready, set_model_com_file_not_ready, check_model_com_file_ready
@@ -40,14 +42,14 @@ def test_retriever():
initialize_megatron(extra_args_provider=None,
                    args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
args = get_args()
model = load_ict_checkpoint(only_block_model=True)
model = load_ict_checkpoint()
model.eval()
dataset = get_ict_dataset()
block_data = BlockData.load_from_file(args.block_data_path)
mips_index = FaissMIPSIndex('flat_ip', 128)
mips_index.add_block_embed_data(block_data)
retriever = REALMRetriever(model, dataset, mips_index, top_k=5)
retriever = REALMRetriever(model, dataset, block_data, mips_index, top_k=5)
strs = [
"The last monarch from the house of windsor",
@@ -71,7 +73,6 @@ def main():
dataset = get_ict_dataset()
data_iter = iter(get_one_epoch_dataloader(dataset))
all_block_data = BlockData()
hashed_index = RandProjectionLSHIndex(embed_size=128, num_buckets=32, whiten=True)
i = 1
total = 0
@@ -103,18 +104,24 @@ def main():
if args.rank == 0:
    all_block_data.consolidate_shards_and_save()
    hashed_index.hash_whitened_block_embeds(all_block_data)
    hashed_index.save_to_file()
else:
    all_block_data.clear()
ran_once = True
set_index_com_file_ready()
torch.distributed.barrier()

while not check_model_com_file_ready():
    time.sleep(5)
set_model_com_file_not_ready()

if args.async_indexer:
    while not check_model_com_file_ready():
        time.sleep(5)

        autoresume = get_adlr_autoresume()
        if autoresume.termination_requested():
            print_rank_0(">>> autoresume termination request found!")
            if torch.distributed.get_rank() == 0:
                autoresume.request_resume()
            print_rank_0(">>> training terminated. Returning")
            sys.exit(0)

    set_model_com_file_not_ready()

def load_ict_checkpoint(only_query_model=False, only_block_model=False, no_grad=False, from_realm_chkpt=False):
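In async mode, the loop above has rank 0 consolidate and save the block embeddings, signal the trainer through a communication file, and then poll every five seconds for the trainer's reply, checking the ADLR autoresume flag while waiting so a preempted indexer job can requeue itself instead of hanging. The indexer_utils helpers themselves are not part of this diff; the sketch below is one plausible marker-file implementation on shared storage, with hypothetical file names.

# Hedged sketch of the com-file helpers imported from indexer_utils (not shown in this
# commit); the paths and file-based signalling below are assumptions, not the actual code.
import os

INDEX_READY_FILE = 'index_com_file.ready'  # hypothetical marker written by the indexer
MODEL_READY_FILE = 'model_com_file.ready'  # hypothetical marker written by the trainer

def set_index_com_file_ready():
    # Tell the trainer that a freshly consolidated block index is on disk.
    open(INDEX_READY_FILE, 'w').close()

def check_model_com_file_ready():
    # The trainer is expected to create this file once a new ICT checkpoint exists.
    return os.path.exists(MODEL_READY_FILE)

def set_model_com_file_not_ready():
    # Consume the trainer's signal before waiting for the next checkpoint.
    if os.path.exists(MODEL_READY_FILE):
        os.remove(MODEL_READY_FILE)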
......
@@ -348,6 +348,8 @@ def _add_data_args(parser):
                   help='Path to pickled data structure for efficient block indexing')
group.add_argument('--block-top-k', type=int, default=5,
                   help='Number of blocks to use as top-k during retrieval')
group.add_argument('--async-indexer', action='store_true',
                   help='Whether the indexer job is running asynchronously with a trainer job')
group.add_argument('--split', type=str, default='969, 30, 1',
                   help='Comma-separated list of proportions for training,'
                        ' validation, and test split. For example the split '
......
@@ -93,6 +93,8 @@ def salient_span_mask(tokens, mask_id):
Note: Tokens here are vocab ids and not text tokens."""
tokenizer = get_tokenizer()
tokens_str = join_str_list(tokenizer.tokenizer.convert_ids_to_tokens(tokens))
print("-" * 100)
print("TOKEN STR\n", tokens_str)
# need to get all named entities
entities = SPACY_NER(tokens_str).ents
@@ -101,6 +103,7 @@ def salient_span_mask(tokens, mask_id):
return None
entity_idx = np.random.randint(0, len(entities))
selected_entity = entities[entity_idx]
print("SELECTED ENTITY\n", selected_entity.text)
token_pos_map = id_to_str_pos_map(tokens, tokenizer)
mask_start = mask_end = 0
@@ -118,6 +121,7 @@ def salient_span_mask(tokens, mask_id):
for id_idx in masked_positions:
    labels.append(tokens[id_idx])
    output_tokens[id_idx] = mask_id

print("OUTPUT\n", join_str_list(tokenizer.tokenizer.convert_ids_to_tokens(output_tokens)))
return output_tokens, masked_positions, labels
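The debug prints added here trace salient span masking: the token ids are detokenized, the string is run through spaCy NER, one named entity is sampled, and the masked positions covering that entity are replaced with the mask id. Below is a simplified sketch of the same idea on plain text (the real code above works on vocab ids and maps them back to string positions via id_to_str_pos_map); the en_core_web_sm model is an assumption.

# Simplified sketch of salient span masking on plain text
# (assumes spacy and the en_core_web_sm model are installed).
import numpy as np
import spacy

nlp = spacy.load('en_core_web_sm')

def salient_span_mask_text(text, mask_token='[MASK]'):
    doc = nlp(text)
    if len(doc.ents) == 0:
        return None
    # sample one named entity uniformly, as in the code above
    entity = doc.ents[np.random.randint(0, len(doc.ents))]
    output_tokens, labels = [], []
    for tok in doc:
        if entity.start <= tok.i < entity.end:
            labels.append(tok.text)            # keep the original token as the label
            output_tokens.append(mask_token)   # mask every token inside the entity span
        else:
            output_tokens.append(tok.text)
    return ' '.join(output_tokens), labels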
......
@@ -192,6 +192,8 @@ class REALMRetriever(MegatronModule):
with torch.no_grad():
    if hasattr(self.ict_model, 'module'):
        true_model = self.ict_model.module
        if hasattr(true_model, 'module'):
            true_model = true_model.module
    else:
        true_model = self.ict_model

    query_embeds = detach(true_model.embed_query(query_tokens, query_pad_mask))
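The added inner check handles a doubly wrapped ICT model, where a DistributedDataParallel wrapper and another wrapper (e.g. an FP16 module) each expose the underlying model as .module. A depth-agnostic sketch of the same unwrapping, for comparison:

# Sketch: peel off any number of wrappers that expose the real model as `.module`
# (e.g. torchDDP around an FP16 wrapper); equivalent to the nested hasattr checks above.
def unwrap_model(model):
    while hasattr(model, 'module'):
        model = model.module
    return model

# usage in the method above: true_model = unwrap_model(self.ict_model)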
......
@@ -87,8 +87,9 @@ def forward_step(data_iterator, model):
timers('batch generator').stop()
# Forward model.
# TODO: MAKE SURE PAD IS NOT 1 - PAD
lm_logits, block_probs = model(tokens, pad_mask, query_block_indices)
with torch.no_grad():
    retrieval_utility = get_retrieval_utility(lm_logits, labels, loss_mask)
# P(y|x) = sum_z(P(y|z, x) * P(z|x))
block_probs = block_probs.unsqueeze(2).unsqueeze(3).expand_as(lm_logits)
@@ -99,9 +100,32 @@ def forward_step(data_iterator, model):
lm_loss = torch.sum(
    lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
reduced_loss = reduce_losses([lm_loss])
reduced_loss = reduce_losses([lm_loss, retrieval_utility])
torch.cuda.synchronize()
return lm_loss, {'lm_loss': reduced_loss[0]}
return lm_loss, {'lm_loss': reduced_loss[0], 'retrieval_utility': reduced_loss[1]}

def get_retrieval_utility(lm_logits, labels, loss_mask):
    """log P(y | z, x) - log P(y | null, x)"""
    # [batch x seq_len x vocab_size]
    null_block_lm_logits = lm_logits[:, -1, :, :]
    null_block_loss_ = mpu.vocab_parallel_cross_entropy(null_block_lm_logits.contiguous().float(),
                                                        labels.contiguous())
    null_block_loss = torch.sum(
        null_block_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

    retrieved_block_losses = []
    for block_num in range(lm_logits.shape[1] - 1):
        retrieved_block_lm_logits = lm_logits[:, block_num, :, :]
        retrieved_block_loss_ = mpu.vocab_parallel_cross_entropy(retrieved_block_lm_logits.contiguous().float(),
                                                                 labels.contiguous())
        retrieved_block_loss = torch.sum(
            retrieved_block_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
        retrieved_block_losses.append(retrieved_block_loss)

    # sum the per-block loss tensors directly; torch.sum does not accept a Python list
    avg_retrieved_block_loss = sum(retrieved_block_losses) / (lm_logits.shape[1] - 1)
    retrieval_utility = null_block_loss - avg_retrieved_block_loss
    return retrieval_utility
def qa_forward_step(data_iterator, model):
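In the notation of the docstring, with L_z the masked-token-averaged negative log-likelihood (cross-entropy) computed above for block z, the quantity added to the logged losses is

    retrieval_utility = L_null - (1/k) * sum_{i=1..k} L_{z_i}
                      = (1/k) * sum_{i=1..k} [ log P(y | z_i, x) - log P(y | null, x) ]   (averaged over masked tokens)

where z_1 ... z_k are the retrieved blocks in the first k slots of lm_logits and the null (empty) block occupies the last slot. A positive value means the retrieved evidence makes the masked tokens easier to predict than no evidence at all. Because it is computed under torch.no_grad() and only passed through reduce_losses, it is reported alongside lm_loss in the training logs without affecting the gradients.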
......