Commit efcee158 authored by Neel Kant

Add null block and exclude trivial block

parent 730266ca
......@@ -483,9 +483,15 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
title_dataset=title_dataset,
**kwargs
)
elif dataset_type == 'realm':
dataset = REALMDataset(
block_dataset=indexed_dataset,
title_dataset=title_dataset,
masked_lm_prob=masked_lm_prob,
**kwargs
)
else:
dataset_cls = BertDataset if dataset_type == 'standard_bert' else REALMDataset
dataset = dataset_cls(
dataset = BertDataset(
indexed_dataset=indexed_dataset,
masked_lm_prob=masked_lm_prob,
**kwargs
......
......@@ -15,30 +15,10 @@ from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_co
#qa_nlp = spacy.load('en_core_web_lg')
class REALMDataset(BertDataset):
"""Dataset containing simple masked sentences for masked language modeling.
The dataset should yield sentences just like the regular BertDataset
However, this dataset also needs to be able to return a set of blocks
given their start and end indices.
Presumably
"""
def __init__(self, name, indexed_dataset, data_prefix,
num_epochs, max_num_samples, masked_lm_prob,
max_seq_length, short_seq_prob, seed):
super(REALMDataset, self).__init__(name, indexed_dataset, data_prefix,
num_epochs, max_num_samples, masked_lm_prob,
max_seq_length, short_seq_prob, seed)
self.build_sample_fn = build_simple_training_sample
def build_simple_training_sample(sample, target_seq_length, max_seq_length,
vocab_id_list, vocab_id_to_token_dict,
cls_id, sep_id, mask_id, pad_id,
masked_lm_prob, np_rng):
tokens = list(itertools.chain(*sample))[:max_seq_length - 2]
tokens, tokentypes = create_single_tokens_and_tokentypes(tokens, cls_id, sep_id)
......@@ -60,6 +40,137 @@ def build_simple_training_sample(sample, target_seq_length, max_seq_length,
return train_sample
class REALMDataset(Dataset):
"""Dataset containing simple masked sentences for masked language modeling.
The dataset should yield sentences just like the regular BertDataset
However, this dataset also needs to be able to return a set of blocks
given their start and end indices.
Presumably
"""
def __init__(self, name, block_dataset, title_dataset, data_prefix,
num_epochs, max_num_samples, masked_lm_prob,
max_seq_length, short_seq_prob, seed):
self.name = name
self.seed = seed
self.max_seq_length = max_seq_length
self.masked_lm_prob = masked_lm_prob
self.block_dataset = block_dataset
self.title_dataset = title_dataset
self.short_seq_prob = short_seq_prob
self.rng = random.Random(self.seed)
self.samples_mapping = self.get_samples_mapping(
data_prefix, num_epochs, max_num_samples)
self.tokenizer = get_tokenizer()
self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
self.vocab_id_to_token_list = self.tokenizer.inv_vocab
self.cls_id = self.tokenizer.cls
self.sep_id = self.tokenizer.sep
self.mask_id = self.tokenizer.mask
self.pad_id = self.tokenizer.pad
def __len__(self):
return self.samples_mapping.shape[0]
def __getitem__(self, idx):
start_idx, end_idx, doc_idx, block_idx = self.samples_mapping[idx]
seq_length = self.max_seq_length
block = [list(self.block_dataset[i]) for i in range(start_idx, end_idx)]
assert len(block) > 1
np_rng = np.random.RandomState(seed=(self.seed + idx))
sample = build_simple_training_sample(block, seq_length,
self.max_seq_length,
self.vocab_id_list,
self.vocab_id_to_token_list,
self.cls_id,
self.sep_id,
self.mask_id,
self.pad_id,
self.masked_lm_prob,
np_rng)
sample.update({'query_block_indices': np.array([block_idx])})
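# The returned dict comes from build_simple_training_sample plus the
# 'query_block_indices' entry added above; get_batch() in pretrain_realm.py
# expects exactly the keys 'tokens', 'labels', 'loss_mask', 'pad_mask',
# 'query_block_indices'.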
return sample
def get_samples_mapping(self, data_prefix, num_epochs, max_num_samples):
if not num_epochs:
if not max_num_samples:
raise ValueError("Need to specify either max_num_samples "
"or num_epochs")
num_epochs = np.iinfo(np.int32).max - 1
if not max_num_samples:
max_num_samples = np.iinfo(np.int64).max - 1
# Filename of the index mapping
indexmap_filename = data_prefix
indexmap_filename += '_{}_indexmap'.format(self.name)
if num_epochs != (np.iinfo(np.int32).max - 1):
indexmap_filename += '_{}ep'.format(num_epochs)
if max_num_samples != (np.iinfo(np.int64).max - 1):
indexmap_filename += '_{}mns'.format(max_num_samples)
indexmap_filename += '_{}msl'.format(self.max_seq_length)
indexmap_filename += '_{}s'.format(self.seed)
indexmap_filename += '.npy'
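# For example (values are hypothetical): with data_prefix='my-corpus',
# name='train', max_num_samples=1000, max_seq_length=288, seed=1234 and
# num_epochs left unset, the cached mapping file would be
# 'my-corpus_train_indexmap_1000mns_288msl_1234s.npy'.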
# Build the indexed mapping if not exist.
if torch.distributed.get_rank() == 0 and \
not os.path.isfile(indexmap_filename):
print(' > WARNING: could not find index map file {}, building '
'the indices on rank 0 ...'.format(indexmap_filename))
# Make sure the types match the helpers input types.
assert self.block_dataset.doc_idx.dtype == np.int64
assert self.block_dataset.sizes.dtype == np.int32
# Build samples mapping
verbose = torch.distributed.get_rank() == 0
start_time = time.time()
print_rank_0(' > building samples index mapping for {} ...'.format(
self.name))
from megatron.data.dataset_utils import compile_helper
compile_helper()
from megatron.data import helpers
samples_mapping = helpers.build_blocks_mapping(
self.block_dataset.doc_idx,
self.block_dataset.sizes,
self.title_dataset.sizes,
num_epochs,
max_num_samples,
self.max_seq_length-3, # account for added tokens
self.seed,
verbose)
print_rank_0(' > done building samples index mapping')
np.save(indexmap_filename, samples_mapping, allow_pickle=True)
print_rank_0(' > saved the index mapping in {}'.format(
indexmap_filename))
# Make sure all the ranks have built the mapping
print_rank_0(' > elapsed time to build and save samples mapping '
'(seconds): {:4f}'.format(
time.time() - start_time))
# This should be a barrier but nccl barrier assumes
# device_index=rank which is not the case for model
# parallel case
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
assert counts[0].item() == torch.distributed.get_world_size(
group=mpu.get_data_parallel_group())
# Load indexed dataset.
print_rank_0(' > loading indexed mapping from {}'.format(
indexmap_filename))
start_time = time.time()
samples_mapping = np.load(indexmap_filename, allow_pickle=True)
print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
time.time() - start_time))
print_rank_0(' total number of samples: {}'.format(
samples_mapping.shape[0]))
return samples_mapping
def create_single_tokens_and_tokentypes(_tokens, cls_id, sep_id):
tokens = []
tokens.append(cls_id)
......@@ -160,6 +271,12 @@ class ICTDataset(Dataset):
return (block_tokens, block_pad_mask)
def get_null_block(self):
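# Builds a block from an empty token list and an empty title, so after
# concat_and_pad_tokens it is just the special tokens plus padding. It is
# appended by REALMRetriever.retrieve_evidence_blocks (later in this commit)
# when include_null_doc=True, presumably as a "no evidence" candidate.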
block, title = [], []
block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
return (block_tokens, block_pad_mask)
def concat_and_pad_tokens(self, tokens, title=None):
"""concat with special tokens and pad sequence to self.max_seq_length"""
tokens = [self.cls_id] + tokens + [self.sep_id]
......
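For intuition, here is a minimal stand-alone sketch of what the null block amounts to. It is an approximation: the real concat_and_pad_tokens also handles the title and is only partially shown above, and null_block_sketch, the token ids 101/102, and pad_id=0 are placeholders rather than values from the repository.

import numpy as np

def null_block_sketch(max_seq_length, cls_id=101, sep_id=102, pad_id=0):
    # An empty block/title pair roughly reduces to [CLS] [SEP] followed by
    # padding, with a pad mask that marks only the two real tokens.
    tokens = [cls_id, sep_id] + [pad_id] * (max_seq_length - 2)
    pad_mask = [1, 1] + [0] * (max_seq_length - 2)
    return np.array(tokens), np.array(pad_mask)

tokens, pad_mask = null_block_sketch(8)
# tokens   -> [101 102   0   0   0   0   0   0]
# pad_mask -> [  1   1   0   0   0   0   0   0]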
......@@ -21,40 +21,43 @@ class REALMBertModel(MegatronModule):
self._lm_key = 'realm_lm'
self.retriever = retriever
self.top_k = self.retriever.top_k
self._retriever_key = 'retriever'
def forward(self, tokens, attention_mask):
# [batch_size x 5 x seq_length]
top5_block_tokens, top5_block_attention_mask = self.retriever.retrieve_evidence_blocks(tokens, attention_mask)
def forward(self, tokens, attention_mask, query_block_indices):
# [batch_size x k x seq_length]
topk_block_tokens, topk_block_attention_mask = self.retriever.retrieve_evidence_blocks(
tokens, attention_mask, query_block_indices=query_block_indices, include_null_doc=True)
batch_size = tokens.shape[0]
seq_length = top5_block_tokens.shape[2]
top5_block_tokens = torch.cuda.LongTensor(top5_block_tokens).reshape(-1, seq_length)
top5_block_attention_mask = torch.cuda.LongTensor(top5_block_attention_mask).reshape(-1, seq_length)
seq_length = topk_block_tokens.shape[2]
topk_block_tokens = torch.cuda.LongTensor(topk_block_tokens).reshape(-1, seq_length)
topk_block_attention_mask = torch.cuda.LongTensor(topk_block_attention_mask).reshape(-1, seq_length)
# [batch_size x 5 x embed_size]
# [batch_size x k x embed_size]
true_model = self.retriever.ict_model.module.module
fresh_block_logits = true_model.embed_block(top5_block_tokens, top5_block_attention_mask).reshape(batch_size, 5, -1)
fresh_block_logits = true_model.embed_block(topk_block_tokens, topk_block_attention_mask)
fresh_block_logits = fresh_block_logits.reshape(batch_size, self.top_k, -1)
# [batch_size x embed_size x 1]
query_logits = true_model.embed_query(tokens, attention_mask).unsqueeze(2)
# [batch_size x 5]
# [batch_size x k]
fresh_block_scores = torch.matmul(fresh_block_logits, query_logits).squeeze()
block_probs = F.softmax(fresh_block_scores, dim=1)
# [batch_size * 5 x seq_length]
tokens = torch.stack([tokens.unsqueeze(1)] * 5, dim=1).reshape(-1, seq_length)
attention_mask = torch.stack([attention_mask.unsqueeze(1)] * 5, dim=1).reshape(-1, seq_length)
# [batch_size * k x seq_length]
tokens = torch.stack([tokens.unsqueeze(1)] * self.top_k, dim=1).reshape(-1, seq_length)
attention_mask = torch.stack([attention_mask.unsqueeze(1)] * self.top_k, dim=1).reshape(-1, seq_length)
# [batch_size * 5 x 2 * seq_length]
all_tokens = torch.cat((tokens, top5_block_tokens), axis=1)
all_attention_mask = torch.cat((attention_mask, top5_block_attention_mask), axis=1)
# [batch_size * k x 2 * seq_length]
all_tokens = torch.cat((tokens, topk_block_tokens), axis=1)
all_attention_mask = torch.cat((attention_mask, topk_block_attention_mask), axis=1)
all_token_types = torch.zeros(all_tokens.shape).type(torch.int64).cuda()
# [batch_size x 5 x 2 * seq_length x vocab_size]
# [batch_size x k x 2 * seq_length x vocab_size]
lm_logits, _ = self.lm_model.forward(all_tokens, all_attention_mask, all_token_types)
lm_logits = lm_logits.reshape(batch_size, 5, 2 * seq_length, -1)
lm_logits = lm_logits.reshape(batch_size, self.top_k, 2 * seq_length, -1)
return lm_logits, block_probs
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
......@@ -101,24 +104,27 @@ class REALMRetriever(MegatronModule):
block_text = self.ict_dataset.decode_tokens(block)
print('\n > Block {}: {}'.format(i, block_text))
def retrieve_evidence_blocks(self, query_tokens, query_pad_mask):
def retrieve_evidence_blocks(self, query_tokens, query_pad_mask, query_block_indices=None, include_null_doc=False):
"""Embed blocks to be used in a forward pass"""
with torch.no_grad():
true_model = self.ict_model.module.module
query_embeds = detach(true_model.embed_query(query_tokens, query_pad_mask))
_, block_indices = self.hashed_index.search_mips_index(query_embeds, top_k=self.top_k, reconstruct=False)
all_top5_tokens, all_top5_pad_masks = [], []
for indices in block_indices:
all_topk_tokens, all_topk_pad_masks = [], []
for query_idx, indices in enumerate(block_indices):
# [k x meta_dim]
top5_metas = np.array([self.block_data.meta_data[idx] for idx in indices])
top5_block_data = [self.ict_dataset.get_block(*block_meta) for block_meta in top5_metas]
top5_tokens, top5_pad_masks = zip(*top5_block_data)
# exclude trivial candidate if it appears, else just trim the weakest in the top-k
topk_metas = [self.block_data.meta_data[idx] for idx in indices if idx != query_block_indices[query_idx]]
topk_block_data = [self.ict_dataset.get_block(*block_meta) for block_meta in topk_metas[:self.top_k - 1]]
if include_null_doc:
topk_block_data.append(self.ict_dataset.get_null_block())
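# With the trivial candidate filtered out and the list trimmed to top_k - 1,
# appending the null block brings the candidate count back to top_k; this is
# why pretrain_realm.py constructs the retriever with block_top_k + 1.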
topk_tokens, topk_pad_masks = zip(*topk_block_data)
all_top5_tokens.append(np.array(top5_tokens))
all_top5_pad_masks.append(np.array(top5_pad_masks))
all_topk_tokens.append(np.array(topk_tokens))
all_topk_pad_masks.append(np.array(topk_pad_masks))
# [batch_size x k x seq_length]
return np.array(all_top5_tokens), np.array(all_top5_pad_masks)
return np.array(all_topk_tokens), np.array(all_topk_pad_masks)
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
......
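Read in isolation, the candidate selection above does the following. This is a simplified restatement over plain Python lists with made-up index values; select_candidates and the 'NULL' placeholder are illustrative names, not the retriever itself.

def select_candidates(retrieved_indices, query_block_index, top_k, null_block):
    # Drop the query's own (trivial) block if the index returned it; otherwise
    # the weakest candidate simply falls off when trimming to top_k - 1.
    kept = [idx for idx in retrieved_indices if idx != query_block_index]
    kept = kept[:top_k - 1]
    # Reserve the final slot for the null block (include_null_doc=True).
    return kept + [null_block]

# The MIPS index returns top_k candidates; here the first is the query's own block.
print(select_candidates([17, 3, 42, 8], query_block_index=17, top_k=4, null_block='NULL'))
# -> [3, 42, 8, 'NULL']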
......@@ -44,8 +44,8 @@ def model_provider():
hashed_index = FaissMIPSIndex(index_type='flat_l2', embed_size=128)
hashed_index.add_block_embed_data(all_block_data)
retriever = REALMRetriever(ict_model, ict_dataset, all_block_data, hashed_index, args.block_top_k)
# TODO: REALMBertModel should accept a path to a pretrained bert-base
# top_k + 1 because we may need to exclude trivial candidate
retriever = REALMRetriever(ict_model, ict_dataset, all_block_data, hashed_index, args.block_top_k + 1)
model = REALMBertModel(retriever)
return model
......@@ -53,7 +53,7 @@ def model_provider():
def get_batch(data_iterator):
# Items and their type.
keys = ['tokens', 'labels', 'loss_mask', 'pad_mask']
keys = ['tokens', 'labels', 'loss_mask', 'pad_mask', 'query_block_indices']
datatype = torch.int64
# Broadcast data.
......@@ -68,8 +68,9 @@ def get_batch(data_iterator):
labels = data_b['labels'].long()
loss_mask = data_b['loss_mask'].long()
pad_mask = data_b['pad_mask'].long()
query_block_indices = data_b['query_block_indices'].long()
return tokens, labels, loss_mask, pad_mask
return tokens, labels, loss_mask, pad_mask, query_block_indices
def forward_step(data_iterator, model):
......@@ -78,16 +79,15 @@ def forward_step(data_iterator, model):
# Get the batch.
timers('batch generator').start()
tokens, labels, loss_mask, pad_mask = get_batch(data_iterator)
tokens, labels, loss_mask, pad_mask, query_block_indices = get_batch(data_iterator)
timers('batch generator').stop()
# Forward model.
# TODO: MAKE SURE PAD IS NOT 1 - PAD
lm_logits, block_probs = model(tokens, pad_mask)
lm_logits, block_probs = model(tokens, pad_mask, query_block_indices)
# P(y|x) = sum_z(P(y|z, x) * P(z|x))
block_probs = block_probs.unsqueeze(2).unsqueeze(3).expand_as(lm_logits)
#block_probs.register_hook(lambda x: print("block_probs: ", x.shape, flush=True))
lm_logits = torch.sum(lm_logits * block_probs, dim=1)[:, :labels.shape[1]]
lm_loss_ = mpu.vocab_parallel_cross_entropy(lm_logits.contiguous().float(),
......@@ -95,7 +95,6 @@ def forward_step(data_iterator, model):
lm_loss = torch.sum(
lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
reduced_loss = reduce_losses([lm_loss])
torch.cuda.synchronize()
print(reduced_loss, flush=True)
......
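A shape-level sketch of the marginalization step in forward_step above, using random tensors and placeholder sizes: it mirrors how the code weights the per-block LM logits P(y|z, x) by the retrieval distribution P(z|x) and sums over the k retrieved blocks, keeping only the query positions.

import torch
import torch.nn.functional as F

batch_size, top_k, seq_length, vocab_size = 2, 4, 6, 11

# P(y|z, x): LM logits for each of the k retrieved blocks (query and block concatenated).
lm_logits = torch.randn(batch_size, top_k, 2 * seq_length, vocab_size)
# P(z|x): retrieval scores over the k blocks, normalized per query.
block_probs = F.softmax(torch.randn(batch_size, top_k), dim=1)

# Broadcast the block probabilities over the sequence and vocab dims,
# sum over z, and keep only the query positions.
weighted = lm_logits * block_probs.unsqueeze(2).unsqueeze(3).expand_as(lm_logits)
marginal = torch.sum(weighted, dim=1)[:, :seq_length]
print(marginal.shape)  # torch.Size([2, 6, 11])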