Commit f2094783 authored by Neel Kant

Add REALMAnswerSpanModel and MLM features

parent c17d880c
......@@ -134,7 +134,7 @@ def load_ict_checkpoint(only_query_model=False, only_block_model=False, no_grad=
return model
def get_ict_dataset():
def get_ict_dataset(use_titles=True):
args = get_args()
block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True)
......@@ -148,7 +148,8 @@ def get_ict_dataset():
max_num_samples=None,
max_seq_length=288, # doesn't matter
short_seq_prob=0.0001, # doesn't matter
seed=1
seed=1,
use_titles=use_titles
)
dataset = ICTDataset(**kwargs)
return dataset
......
......@@ -10,9 +10,7 @@ from torch.utils.data import Dataset
from megatron import get_tokenizer, print_rank_0, mpu
from megatron.data.bert_dataset import BertDataset
from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy
#qa_nlp = spacy.load('en_core_web_lg')
from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy, is_start_piece, MaskedLmInstance
def build_simple_training_sample(sample, target_seq_length, max_seq_length,
......@@ -40,6 +38,169 @@ def build_simple_training_sample(sample, target_seq_length, max_seq_length,
return train_sample
qa_nlp = spacy.load('en_core_web_lg')
def salient_span_mask(tokens, vocab_id_list, vocab_id_to_token_dict,
                      cls_id, sep_id, mask_id, np_rng,
                      masked_lm_prob, max_predictions_per_seq,
                      max_ngrams=3, do_permutation=False):
"""Creates the predictions for the masked LM objective.
Note: Tokens here are vocab ids and not text tokens."""
cand_indexes = []
    # Note(mingdachen): We create a list recording whether each piece is the
    # starting piece of a whole word (1 means true), so that on-the-fly
    # whole-word masking is possible.
token_boundary = [0] * len(tokens)
for (i, token) in enumerate(tokens):
if token == cls_id or token == sep_id:
token_boundary[i] = 1
continue
        # Whole Word Masking means that we mask all of the wordpieces
        # corresponding to an original word.
#
# Note that Whole Word Masking does *not* change the training code
# at all -- we still predict each WordPiece independently, softmaxed
# over the entire vocabulary.
if len(cand_indexes) >= 1 and not is_start_piece(vocab_id_to_token_dict[token]):
cand_indexes[-1].append(i)
else:
cand_indexes.append([i])
if is_start_piece(vocab_id_to_token_dict[token]):
token_boundary[i] = 1
    output_tokens = list(tokens)

    masked_lm_positions = []
    masked_lm_labels = []

    # masking budget and n-gram geometry (as in create_masked_lm_predictions):
    # candidate n-gram lengths and their sampling probabilities
    num_to_predict = min(max_predictions_per_seq,
                         max(1, int(round(len(tokens) * masked_lm_prob))))
    ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64)
    pvals = 1. / ngrams / np.sum(1. / ngrams)

    ngram_indexes = []
for idx in range(len(cand_indexes)):
ngram_index = []
for n in ngrams:
ngram_index.append(cand_indexes[idx:idx + n])
ngram_indexes.append(ngram_index)
np_rng.shuffle(ngram_indexes)
masked_lms = []
covered_indexes = set()
for cand_index_set in ngram_indexes:
if len(masked_lms) >= num_to_predict:
break
if not cand_index_set:
continue
        # Note(mingdachen):
        # Skip the current piece if it is covered by LM masking or previous n-grams.
for index_set in cand_index_set[0]:
for index in index_set:
if index in covered_indexes:
continue
n = np_rng.choice(ngrams[:len(cand_index_set)],
p=pvals[:len(cand_index_set)] /
pvals[:len(cand_index_set)].sum(keepdims=True))
index_set = sum(cand_index_set[n - 1], [])
n -= 1
        # Note(mingdachen):
        # Repeatedly look for a candidate that does not exceed the maximum
        # number of predictions by trying shorter n-grams.
while len(masked_lms) + len(index_set) > num_to_predict:
if n == 0:
break
index_set = sum(cand_index_set[n - 1], [])
n -= 1
# If adding a whole-word mask would exceed the maximum number of
# predictions, then just skip this candidate.
if len(masked_lms) + len(index_set) > num_to_predict:
continue
is_any_index_covered = False
for index in index_set:
if index in covered_indexes:
is_any_index_covered = True
break
if is_any_index_covered:
continue
for index in index_set:
covered_indexes.add(index)
masked_token = None
# 80% of the time, replace with [MASK]
if np_rng.random() < 0.8:
masked_token = mask_id
else:
# 10% of the time, keep original
if np_rng.random() < 0.5:
masked_token = tokens[index]
# 10% of the time, replace with random word
else:
masked_token = vocab_id_list[np_rng.randint(0, len(vocab_id_list))]
output_tokens[index] = masked_token
masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
assert len(masked_lms) <= num_to_predict
np_rng.shuffle(ngram_indexes)
select_indexes = set()
if do_permutation:
for cand_index_set in ngram_indexes:
if len(select_indexes) >= num_to_predict:
break
if not cand_index_set:
continue
            # Note(mingdachen):
            # Skip the current piece if it is covered by LM masking or previous n-grams.
for index_set in cand_index_set[0]:
for index in index_set:
if index in covered_indexes or index in select_indexes:
continue
            # use the seeded generator (np_rng) rather than the global numpy state
            n = np_rng.choice(ngrams[:len(cand_index_set)],
                              p=pvals[:len(cand_index_set)] /
                              pvals[:len(cand_index_set)].sum(keepdims=True))
index_set = sum(cand_index_set[n - 1], [])
n -= 1
while len(select_indexes) + len(index_set) > num_to_predict:
if n == 0:
break
index_set = sum(cand_index_set[n - 1], [])
n -= 1
# If adding a whole-word mask would exceed the maximum number of
# predictions, then just skip this candidate.
if len(select_indexes) + len(index_set) > num_to_predict:
continue
is_any_index_covered = False
for index in index_set:
if index in covered_indexes or index in select_indexes:
is_any_index_covered = True
break
if is_any_index_covered:
continue
for index in index_set:
select_indexes.add(index)
assert len(select_indexes) <= num_to_predict
select_indexes = sorted(select_indexes)
permute_indexes = list(select_indexes)
np_rng.shuffle(permute_indexes)
orig_token = list(output_tokens)
for src_i, tgt_i in zip(select_indexes, permute_indexes):
output_tokens[src_i] = orig_token[tgt_i]
masked_lms.append(MaskedLmInstance(index=src_i, label=orig_token[src_i]))
masked_lms = sorted(masked_lms, key=lambda x: x.index)
for p in masked_lms:
masked_lm_positions.append(p.index)
masked_lm_labels.append(p.label)
return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary)
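A minimal usage sketch of the masking routine above, with a hypothetical toy vocabulary (ids and word pieces invented for illustration, assuming the masked_lm_prob / max_predictions_per_seq arguments in the signature above):

    # hypothetical toy vocab: 0=[CLS], 1=[SEP], 2=[MASK], 3..8 word pieces
    vocab_id_to_token_dict = {0: '[CLS]', 1: '[SEP]', 2: '[MASK]', 3: 'the',
                              4: 'cat', 5: '##s', 6: 'sat', 7: 'on', 8: 'mat'}
    vocab_id_list = list(vocab_id_to_token_dict.keys())
    tokens = [0, 3, 4, 5, 6, 7, 3, 8, 1]  # [CLS] the cat ##s sat on the mat [SEP]
    np_rng = np.random.RandomState(seed=1)
    output_tokens, positions, labels, boundary = salient_span_mask(
        tokens, vocab_id_list, vocab_id_to_token_dict, cls_id=0, sep_id=1,
        mask_id=2, np_rng=np_rng, masked_lm_prob=0.15, max_predictions_per_seq=2)
    # whole-word masking: '##s' can only be masked together with 'cat'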
class REALMDataset(Dataset):
"""Dataset containing simple masked sentences for masked language modeling.
......@@ -196,7 +357,7 @@ class ICTDataset(Dataset):
"""Dataset containing sentences and their blocks for an inverse cloze task."""
def __init__(self, name, block_dataset, title_dataset, data_prefix,
num_epochs, max_num_samples, max_seq_length,
short_seq_prob, seed):
short_seq_prob, seed, use_titles=True):
self.name = name
self.seed = seed
self.max_seq_length = max_seq_length
......@@ -204,6 +365,7 @@ class ICTDataset(Dataset):
self.title_dataset = title_dataset
self.short_seq_prob = short_seq_prob
self.rng = random.Random(self.seed)
self.use_titles = use_titles
self.samples_mapping = self.get_samples_mapping(
data_prefix, num_epochs, max_num_samples)
......@@ -220,15 +382,16 @@ class ICTDataset(Dataset):
def __getitem__(self, idx):
start_idx, end_idx, doc_idx, block_idx = self.samples_mapping[idx]
title = list(self.title_dataset[int(doc_idx)])
if self.use_titles:
title = list(self.title_dataset[int(doc_idx)])
title_pad_offset = 3 + len(title)
else:
title = None
title_pad_offset = 2
block = [list(self.block_dataset[i]) for i in range(start_idx, end_idx)]
assert len(block) > 1
# avoid selecting the first or last sentence to be the query.
if len(block) == 2:
rand_sent_idx = int(self.rng.random() > 0.5)
else:
rand_sent_idx = self.rng.randint(1, len(block) - 2)
rand_sent_idx = self.rng.randint(0, len(block) - 1)
        # keep the query in the context 10% of the time.
        if self.rng.random() < 0.1:
......@@ -239,7 +402,7 @@ class ICTDataset(Dataset):
# still need to truncate because blocks are concluded when
# the sentence lengths have exceeded max_seq_length.
query = query[:self.max_seq_length - 2]
block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset]
query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
......@@ -279,9 +442,10 @@ class ICTDataset(Dataset):
def concat_and_pad_tokens(self, tokens, title=None):
"""concat with special tokens and pad sequence to self.max_seq_length"""
tokens = [self.cls_id] + tokens + [self.sep_id]
if title is not None:
tokens += title + [self.sep_id]
if title is None:
tokens = [self.cls_id] + tokens + [self.sep_id]
else:
tokens = [self.cls_id] + title + [self.sep_id] + tokens + [self.sep_id]
assert len(tokens) <= self.max_seq_length, len(tokens)
num_pad = self.max_seq_length - len(tokens)
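For reference, the two layouts produced above, with hypothetical ids (cls=101, sep=102, pad=0) and max_seq_length=10:

    tokens = [5, 6, 7]; title = [9]
    # with title:    [101, 9, 102, 5, 6, 7, 102, 0, 0, 0]  -> offset 3 + len(title)
    # without title: [101, 5, 6, 7, 102, 0, 0, 0, 0, 0]    -> offset 2

These match the title_pad_offset values used in __getitem__ above.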
......
......@@ -2,12 +2,79 @@ import numpy as np
import torch
import torch.nn.functional as F
from megatron import get_args
from megatron.checkpointing import load_checkpoint
from megatron.data.realm_index import detach
from megatron.model import BertModel
from megatron.model.utils import get_linear_layer, init_method_normal
from megatron.module import MegatronModule
class REALMAnswerSpanModel(MegatronModule):
def __init__(self, realm_model, mlp_hidden_size=64):
super(REALMAnswerSpanModel, self).__init__()
self.realm_model = realm_model
self.mlp_hidden_size = mlp_hidden_size
args = get_args()
init_method = init_method_normal(args.init_method_std)
self.fc1 = get_linear_layer(2 * args.hidden_size, self.mlp_hidden_size, init_method)
self._fc1_key = 'fc1'
self.fc2 = get_linear_layer(self.mlp_hidden_size, 1, init_method)
self._fc2_key = 'fc2'
        # enumerate every candidate answer span (start, end), inclusive, of up
        # to max_length tokens over a 288-token evidence block
        max_length = 10
        self.start_ends = []
        for length in range(max_length):
            self.start_ends.extend([(i, i + length) for i in range(288 - length)])
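For intuition, the same enumeration with a hypothetical block length of 5 and max_length of 2:

    start_ends = []
    for length in range(2):
        start_ends.extend([(i, i + length) for i in range(5 - length)])
    # -> [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (0, 1), (1, 2), (2, 3), (3, 4)]
    # with the real values (288, 10) this gives 10 * 288 - 45 = 2835 candidate spans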
def forward(self, question_tokens, question_attention_mask, answer_tokens, answer_token_lengths):
lm_logits, block_probs, topk_block_tokens = self.realm_model(
question_tokens, question_attention_mask, query_block_indices=None, return_topk_block_tokens=True)
batch_span_reps, batch_loss_masks = [], []
# go through batch one-by-one
for i in range(len(answer_token_lengths)):
answer_length = answer_token_lengths[i]
answer_span_tokens = answer_tokens[i][:answer_length]
span_reps, loss_masks = [], []
# go through the top k for the batch item
for logits, block_tokens in zip(lm_logits[i], topk_block_tokens[i]):
                # the second half of the logits corresponds to the evidence block
                # positions; integer division, since len(logits) == 2 * seq_length
                block_logits = logits[len(logits) // 2:]
span_starts = range(len(block_tokens) - (answer_length - 1))
# record the start, end indices of spans which match the answer
matching_indices = set([
(idx, idx + answer_length - 1) for idx in span_starts
if np.array_equal(block_tokens[idx:idx + answer_length], answer_span_tokens)
])
                # create a mask for computing the loss on P(y | z, x)
                # [num_spans]
                loss_masks.append(torch.cuda.LongTensor(
                    [int(idx_pair in matching_indices) for idx_pair in self.start_ends]))
                # get all of the candidate spans that need to be fed to the MLP
                # [num_spans x 2 * embed_size]
                span_reps.append(torch.stack([
                    torch.cat((block_logits[s], block_logits[e]))
                    for (s, e) in self.start_ends]))
# data for all k blocks for a single batch item
# [k x num_spans]
batch_loss_masks.append(torch.stack(loss_masks))
# [k x num_spans x 2 * embed_size]
batch_span_reps.append(torch.stack(span_reps))
# data for all batch items
# [batch_size x k x num_spans]
batch_loss_masks = torch.stack(batch_loss_masks)
batch_span_reps = torch.stack(batch_span_reps)
        # [batch_size x k x num_spans]; a ReLU between the two linear layers makes
        # fc1/fc2 an actual MLP, and squeeze(-1) only drops the final singleton
        # dim (a bare squeeze() would also drop a batch dimension of size 1)
        batch_span_logits = self.fc2(F.relu(self.fc1(batch_span_reps))).squeeze(-1)
return batch_span_logits, batch_loss_masks, block_probs
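A hedged shape walk-through of the forward pass above, with batch size b, top-k k, and s = len(self.start_ends) spans. Note the span representations are built from vocabulary-sized LM logits, so the 2 * hidden_size input expected by fc1 only lines up if the underlying LM exposes vectors of that width ("width" below stands for whichever it is):

    # lm_logits:           [b x k x 2 * seq_length x width]
    # block_logits:        [seq_length x width]   (second half of each row)
    # span rep for (s, e): cat(start, end)       -> [2 * width]
    # batch_span_reps:     [b x k x s x 2 * width]
    # batch_span_logits:   [b x k x s]   after fc2(relu(fc1(...))).squeeze(-1)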
# block_probs = block_probs.unsqueeze(2).unsqueeze(3).expand_as(lm_logits)
# lm_logits = torch.sum(lm_logits * block_probs, dim=1)
class REALMBertModel(MegatronModule):
def __init__(self, retriever):
super(REALMBertModel, self).__init__()
......@@ -24,11 +91,13 @@ class REALMBertModel(MegatronModule):
self.top_k = self.retriever.top_k
self._retriever_key = 'retriever'
def forward(self, tokens, attention_mask, query_block_indices):
def forward(self, tokens, attention_mask, query_block_indices, return_topk_block_tokens=False):
# [batch_size x k x seq_length]
topk_block_tokens, topk_block_attention_mask = self.retriever.retrieve_evidence_blocks(
tokens, attention_mask, query_block_indices=query_block_indices, include_null_doc=True)
batch_size = tokens.shape[0]
# create a copy in case it needs to be returned
ret_topk_block_tokens = np.array(topk_block_tokens)
seq_length = topk_block_tokens.shape[2]
topk_block_tokens = torch.cuda.LongTensor(topk_block_tokens).reshape(-1, seq_length)
......@@ -58,6 +127,10 @@ class REALMBertModel(MegatronModule):
# [batch_size x k x 2 * seq_length x vocab_size]
lm_logits, _ = self.lm_model.forward(all_tokens, all_attention_mask, all_token_types)
lm_logits = lm_logits.reshape(batch_size, self.top_k, 2 * seq_length, -1)
if return_topk_block_tokens:
return lm_logits, block_probs, ret_topk_block_tokens
return lm_logits, block_probs
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
......@@ -111,6 +184,11 @@ class REALMRetriever(MegatronModule):
query_embeds = detach(true_model.embed_query(query_tokens, query_pad_mask))
_, block_indices = self.hashed_index.search_mips_index(query_embeds, top_k=self.top_k, reconstruct=False)
all_topk_tokens, all_topk_pad_masks = [], []
# this will result in no candidate exclusion
if query_block_indices is None:
query_block_indices = [-1] * len(block_indices)
for query_idx, indices in enumerate(block_indices):
# [k x meta_dim]
# exclude trivial candidate if it appears, else just trim the weakest in the top-k
......
......@@ -83,11 +83,11 @@ def forward_step(data_iterator, model):
retrieval_scores = model(query_tokens, query_pad_mask, block_tokens, block_pad_mask).float()
softmaxed = F.softmax(retrieval_scores, dim=1)
top5_vals, top5_indices = torch.topk(softmaxed, k=5, sorted=True)
sorted_vals, sorted_indices = torch.topk(softmaxed, k=softmaxed.shape[1], sorted=True)
batch_size = softmaxed.shape[0]
top1_acc = torch.cuda.FloatTensor([sum([int(top5_indices[i, 0] == i) for i in range(batch_size)]) / batch_size])
top5_acc = torch.cuda.FloatTensor([sum([int(i in top5_indices[i]) for i in range(batch_size)]) / batch_size])
top1_acc = torch.cuda.FloatTensor([sum([int(sorted_indices[i, 0] == i) for i in range(batch_size)]) / batch_size])
top5_acc = torch.cuda.FloatTensor([sum([int(i in sorted_indices[i, :5]) for i in range(batch_size)]) / batch_size])
    # cross_entropy applies log-softmax itself, so pass the raw scores
    retrieval_loss = F.cross_entropy(retrieval_scores, torch.arange(batch_size).cuda())
reduced_losses = reduce_losses([retrieval_loss, top1_acc, top5_acc])
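As a sanity check, a toy run of the accuracy computation above (the in-batch labels are the diagonal: query i should retrieve block i):

    softmaxed = torch.tensor([[0.7, 0.2, 0.1],
                              [0.1, 0.8, 0.1],
                              [0.5, 0.4, 0.1]])
    _, sorted_indices = torch.topk(softmaxed, k=3, sorted=True)
    top1 = sum(int(sorted_indices[i, 0] == i) for i in range(3)) / 3  # 2/3: row 2 ranks block 0 first
    topk = sum(int(i in sorted_indices[i, :3]) for i in range(3)) / 3  # 1.0 (k capped at 3 here)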
......
......@@ -38,7 +38,7 @@ def model_provider():
print_rank_0('building REALM models ...')
ict_model = load_ict_checkpoint()
ict_dataset = get_ict_dataset()
ict_dataset = get_ict_dataset(use_titles=False)
all_block_data = BlockData.load_from_file(args.block_data_path)
# hashed_index = RandProjectionLSHIndex.load_from_file(args.block_index_path)
hashed_index = FaissMIPSIndex(index_type='flat_l2', embed_size=128)
......@@ -73,6 +73,11 @@ def get_batch(data_iterator):
return tokens, labels, loss_mask, pad_mask, query_block_indices
def get_qa_batch(data_iterator):
question_tokens, question_attention_mask, answer_tokens, answer_token_lengths = next(data_iterator)
return question_tokens, question_attention_mask, answer_tokens, answer_token_lengths
def forward_step(data_iterator, model):
"""Forward step."""
timers = get_timers()
......@@ -101,6 +106,26 @@ def forward_step(data_iterator, model):
return lm_loss, {'lm_loss': reduced_loss[0]}
def qa_forward_step(data_iterator, model):
timers = get_timers()
# this dataset interface needs to be implemented
timers('batch generator').start()
question_tokens, question_attention_mask, answer_tokens, answer_token_lengths = get_qa_batch(data_iterator)
timers('batch generator').stop()
batch_span_logits, batch_loss_masks, block_probs = model(question_tokens, question_attention_mask,
answer_tokens, answer_token_lengths)
    # [batch_size x k x num_spans]
    block_probs = block_probs.unsqueeze(2).expand_as(batch_span_logits)
    batch_span_probs = F.softmax(batch_span_logits, dim=2)

    # mask P(y | z, x) down to the spans that match the answer, then
    # marginalize over the top-k blocks: P(y | x) = sum_z P(y | z, x) P(z | x)
    # [batch_size x num_spans]
    reduced_block_span_probs = torch.sum(
        batch_span_probs * batch_loss_masks.float() * block_probs, dim=1)

    # [batch_size]
    qa_span_loss_ = -torch.log(reduced_block_span_probs.sum(dim=1) + 1e-12)
    qa_span_loss = torch.mean(qa_span_loss_)

    reduced_loss = reduce_losses([qa_span_loss])
    return qa_span_loss, {'qa_span_loss': reduced_loss[0]}
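For reference, the loss above follows the REALM marginal log-likelihood over retrieved blocks z and the set S(z, y) of spans in z matching answer y:

    L(x, y) = -log sum_z P(z | x) * sum_{s in S(z, y)} P(s | z, x)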
def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid and test datasets."""
args = get_args()
......