Commit 2f7d666c authored by Neel Kant

Add retrieval utility and autoresume for indexer

parent 9b9b8e01
 import os
+import sys
 import time
 import torch
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
-from megatron import get_args
+from megatron import get_args, get_adlr_autoresume, print_rank_0
 from megatron import mpu
 from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
 from megatron.data.bert_dataset import get_indexed_dataset_
 from megatron.data.realm_dataset import ICTDataset
-from megatron.data.realm_index import detach, BlockData, RandProjectionLSHIndex
+from megatron.data.realm_index import detach, BlockData, FaissMIPSIndex
 from megatron.data.samplers import DistributedBatchSampler
 from megatron.initialize import initialize_megatron
 from megatron.model import REALMRetriever
 from megatron.training import get_model
+from megatron.utils import check_adlr_autoresume_termination
 from pretrain_bert_ict import get_batch, model_provider
 from indexer_utils import set_index_com_file_ready, set_model_com_file_not_ready, check_model_com_file_ready
@@ -40,14 +42,14 @@ def test_retriever():
     initialize_megatron(extra_args_provider=None,
                         args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
     args = get_args()
-    model = load_ict_checkpoint(only_block_model=True)
+    model = load_ict_checkpoint()
     model.eval()
     dataset = get_ict_dataset()
     block_data = BlockData.load_from_file(args.block_data_path)
-    retriever = REALMRetriever(model, dataset, mips_index, top_k=5)
+    mips_index = FaissMIPSIndex('flat_ip', 128)
+    mips_index.add_block_embed_data(block_data)
+    retriever = REALMRetriever(model, dataset, block_data, mips_index, top_k=5)

     strs = [
         "The last monarch from the house of windsor",
@@ -71,7 +73,6 @@ def main():
     dataset = get_ict_dataset()
     data_iter = iter(get_one_epoch_dataloader(dataset))
     all_block_data = BlockData()
-    hashed_index = RandProjectionLSHIndex(embed_size=128, num_buckets=32, whiten=True)

     i = 1
     total = 0
@@ -103,18 +104,24 @@ def main():
         if args.rank == 0:
             all_block_data.consolidate_shards_and_save()
-            hashed_index.hash_whitened_block_embeds(all_block_data)
-            hashed_index.save_to_file()
         else:
             all_block_data.clear()

         ran_once = True
         set_index_com_file_ready()
         torch.distributed.barrier()

-        while not check_model_com_file_ready():
-            time.sleep(5)
+        if args.async_indexer:
+            while not check_model_com_file_ready():
+                time.sleep(5)
+                autoresume = get_adlr_autoresume()
+                if autoresume.termination_requested():
+                    print_rank_0(">>> autoresume termination request found!")
+                    if torch.distributed.get_rank() == 0:
+                        autoresume.request_resume()
+                    print_rank_0(">>> training terminated. Returning")
+                    sys.exit(0)

-        set_model_com_file_not_ready()
+            set_model_com_file_not_ready()
 def load_ict_checkpoint(only_query_model=False, only_block_model=False, no_grad=False, from_realm_chkpt=False):
...
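The loop above coordinates the async indexer with the trainer through the communication files from `indexer_utils`, and now also polls ADLR autoresume so the job can requeue itself cleanly. The `indexer_utils` implementation is not part of this diff; the sketch below is a hypothetical file-flag version of the same handshake (paths and flag format are assumptions) to show the intended protocol:

```python
import os
import time

INDEX_READY_FLAG = "/tmp/realm_index_ready"   # assumed location
MODEL_READY_FLAG = "/tmp/realm_model_ready"   # assumed location

def set_index_com_file_ready():
    open(INDEX_READY_FLAG, "w").close()       # indexer: "a fresh block index is available"

def check_model_com_file_ready():
    return os.path.exists(MODEL_READY_FLAG)   # trainer drops this flag after reloading the index

def set_model_com_file_not_ready():
    if os.path.exists(MODEL_READY_FLAG):
        os.remove(MODEL_READY_FLAG)           # reset so the next round blocks again

def wait_for_trainer(poll_seconds=5, termination_requested=lambda: False):
    # Mirrors the indexer loop in main(): wait for the trainer to pick up the new index,
    # but exit cleanly if the cluster's autoresume system asks the job to stop and requeue.
    while not check_model_com_file_ready():
        time.sleep(poll_seconds)
        if termination_requested():           # ADLR autoresume in the real code
            raise SystemExit(0)
    set_model_com_file_not_ready()
```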
@@ -348,6 +348,8 @@ def _add_data_args(parser):
                        help='Path to pickled data structure for efficient block indexing')
     group.add_argument('--block-top-k', type=int, default=5,
                        help='Number of blocks to use as top-k during retrieval')
+    group.add_argument('--async-indexer', action='store_true',
+                       help='Whether the indexer job is running asynchronously with a trainer job')
     group.add_argument('--split', type=str, default='969, 30, 1',
                        help='Comma-separated list of proportions for training,'
                        ' validation, and test split. For example the split '
...
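For reference, the new flag behaves like any argparse `store_true` option: `args.async_indexer` is `False` unless `--async-indexer` is passed, which is what gates the wait loop in the indexer's `main()`. A standalone illustration:

```python
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group('data')
group.add_argument('--async-indexer', action='store_true',
                   help='Whether the indexer job is running asynchronously with a trainer job')

print(parser.parse_args([]).async_indexer)                    # False
print(parser.parse_args(['--async-indexer']).async_indexer)   # True
```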
@@ -93,6 +93,8 @@ def salient_span_mask(tokens, mask_id):
     Note: Tokens here are vocab ids and not text tokens."""
     tokenizer = get_tokenizer()
     tokens_str = join_str_list(tokenizer.tokenizer.convert_ids_to_tokens(tokens))
+    print("-" * 100)
+    print("TOKEN STR\n", tokens_str)

     # need to get all named entities
     entities = SPACY_NER(tokens_str).ents
@@ -101,6 +103,7 @@ def salient_span_mask(tokens, mask_id):
         return None
     entity_idx = np.random.randint(0, len(entities))
     selected_entity = entities[entity_idx]
+    print("SELECTED ENTITY\n", selected_entity.text)

     token_pos_map = id_to_str_pos_map(tokens, tokenizer)
     mask_start = mask_end = 0
@@ -118,6 +121,7 @@ def salient_span_mask(tokens, mask_id):
     for id_idx in masked_positions:
         labels.append(tokens[id_idx])
         output_tokens[id_idx] = mask_id
+    print("OUTPUT\n", join_str_list(tokenizer.tokenizer.convert_ids_to_tokens(output_tokens)))

     return output_tokens, masked_positions, labels
...
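The prints added above trace the three stages of salient span masking: the detokenized string, the randomly chosen named entity, and the masked output. For context, a simplified sketch of the technique itself, working on plain whitespace tokens rather than vocab ids (so `id_to_str_pos_map` and `join_str_list`, which are not in this diff, are not needed):

```python
import numpy as np
import spacy

SPACY_NER = spacy.load("en_core_web_sm")   # assumes the small English model is installed

def salient_span_mask_text(tokens, mask_token="[MASK]"):
    """Mask every token that overlaps one randomly chosen named entity."""
    text = " ".join(tokens)
    entities = SPACY_NER(text).ents
    if len(entities) == 0:
        return None
    entity = entities[np.random.randint(0, len(entities))]

    # Character span of each token inside the joined string.
    positions, start = [], 0
    for tok in tokens:
        positions.append((start, start + len(tok)))
        start += len(tok) + 1

    output, masked_positions, labels = list(tokens), [], []
    for i, (s, e) in enumerate(positions):
        if s < entity.end_char and e > entity.start_char:   # token overlaps the entity span
            masked_positions.append(i)
            labels.append(tokens[i])
            output[i] = mask_token
    return output, masked_positions, labels
```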
@@ -192,6 +192,8 @@ class REALMRetriever(MegatronModule):
         with torch.no_grad():
             if hasattr(self.ict_model, 'module'):
                 true_model = self.ict_model.module
+                if hasattr(true_model, 'module'):
+                    true_model = true_model.module
             else:
                 true_model = self.ict_model
             query_embeds = detach(true_model.embed_query(query_tokens, query_pad_mask))
...
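The two added lines handle a doubly wrapped model (for example a DDP wrapper around an FP16 wrapper). A generic equivalent of the same unwrapping, as a sketch rather than the repository's code:

```python
def unwrap_model(model):
    """Peel off nested .module wrappers (DDP, FP16, etc.) until the bare model remains."""
    while hasattr(model, 'module'):
        model = model.module
    return model

# true_model = unwrap_model(self.ict_model) would replace the nested hasattr checks above.
```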
@@ -87,8 +87,9 @@ def forward_step(data_iterator, model):
     timers('batch generator').stop()

     # Forward model.
-    # TODO: MAKE SURE PAD IS NOT 1 - PAD
     lm_logits, block_probs = model(tokens, pad_mask, query_block_indices)
+    with torch.no_grad():
+        retrieval_utility = get_retrieval_utility(lm_logits, labels, loss_mask)

     # P(y|x) = sum_z(P(y|z, x) * P(z|x))
     block_probs = block_probs.unsqueeze(2).unsqueeze(3).expand_as(lm_logits)
@@ -99,9 +100,32 @@ def forward_step(data_iterator, model):
     lm_loss = torch.sum(
         lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

-    reduced_loss = reduce_losses([lm_loss])
+    reduced_loss = reduce_losses([lm_loss, retrieval_utility])
     torch.cuda.synchronize()

-    return lm_loss, {'lm_loss': reduced_loss[0]}
+    return lm_loss, {'lm_loss': reduced_loss[0], 'retrieval_utility': reduced_loss[1]}
+
+
+def get_retrieval_utility(lm_logits, labels, loss_mask):
+    """log P(y | z, x) - log P(y | null, x)"""
+    # lm_logits: [batch x num_blocks x seq_len x vocab_size]; the last block is the null (no-evidence) block
+    null_block_lm_logits = lm_logits[:, -1, :, :]
+    null_block_loss_ = mpu.vocab_parallel_cross_entropy(null_block_lm_logits.contiguous().float(),
+                                                        labels.contiguous())
+    null_block_loss = torch.sum(
+        null_block_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
+
+    retrieved_block_losses = []
+    for block_num in range(lm_logits.shape[1] - 1):
+        retrieved_block_lm_logits = lm_logits[:, block_num, :, :]
+        retrieved_block_loss_ = mpu.vocab_parallel_cross_entropy(retrieved_block_lm_logits.contiguous().float(),
+                                                                 labels.contiguous())
+        retrieved_block_loss = torch.sum(
+            retrieved_block_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
+        retrieved_block_losses.append(retrieved_block_loss)
+    avg_retrieved_block_loss = sum(retrieved_block_losses) / (lm_logits.shape[1] - 1)
+
+    retrieval_utility = null_block_loss - avg_retrieved_block_loss
+    return retrieval_utility


 def qa_forward_step(data_iterator, model):
...
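Retrieval utility measures how much the retrieved evidence blocks help the masked-LM objective: the cross-entropy of the null (no-evidence) block minus the average cross-entropy over the retrieved blocks, i.e. log P(y | z, x) - log P(y | null, x); a positive value means retrieval lowers the loss. A self-contained sketch of the same computation with plain PyTorch in place of `mpu.vocab_parallel_cross_entropy` (a stated substitution), run on dummy shapes:

```python
import torch
import torch.nn.functional as F

def masked_token_loss(logits, labels, loss_mask):
    # logits: [batch, seq_len, vocab]; labels/loss_mask: [batch, seq_len]
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), labels.reshape(-1), reduction='none')
    return torch.sum(loss * loss_mask.reshape(-1)) / loss_mask.sum()

def retrieval_utility(lm_logits, labels, loss_mask):
    """log P(y | z, x) - log P(y | null, x); the null block is the last block dimension."""
    null_loss = masked_token_loss(lm_logits[:, -1], labels, loss_mask)
    block_losses = [masked_token_loss(lm_logits[:, b], labels, loss_mask)
                    for b in range(lm_logits.shape[1] - 1)]
    # Higher utility means the retrieved blocks lower the masked-LM loss relative to no evidence.
    return null_loss - sum(block_losses) / len(block_losses)

batch, blocks, seq_len, vocab = 2, 5, 8, 30
lm_logits = torch.randn(batch, blocks, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))
loss_mask = torch.ones(batch, seq_len)
print(retrieval_utility(lm_logits, labels, loss_mask))
```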