Commit ca6b6687 authored by Neel Kant

Fix InverseClozeDataset behavior (with commented-out test code)

parent 423c51b0
@@ -966,6 +966,7 @@ class InverseClozeDataset(data.Dataset):
padless_max_len = self.max_seq_len - 2
# select a random sentence from the document as input
# TODO: consider adding multiple input sentences.
input_sentence_idx = rng.randint(0, num_sentences - 1)
tokens, token_types = self.sentence_tokenize(doc[input_sentence_idx], 0)
input_tokens, input_token_types = tokens[:target_seq_length], token_types[:target_seq_length]
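For reference, the selection step above can be read as the following standalone sketch, assuming rng is Python's random.Random and substituting a plain whitespace split for the dataset's sentence_tokenize; doc and target_seq_length here are illustrative stand-ins, not values from the dataset:

import random

# Illustrative stand-ins for the dataset's document and length cap.
rng = random.Random(1234)
doc = ["First sentence.", "Second sentence.", "Third sentence."]
target_seq_length = 32

# random.Random.randint(a, b) is inclusive on both ends, so the upper
# bound has to be len(doc) - 1 to stay inside the document.
input_sentence_idx = rng.randint(0, len(doc) - 1)
input_tokens = doc[input_sentence_idx].split()[:target_seq_length]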
@@ -976,14 +977,17 @@
# 10% of the time, the input sentence is left in the context.
# The other 90% of the time, remove it.
if rng.random() < 0.1:
# if True:
context_tokens = input_tokens.copy()
context_token_types = input_token_types.copy()
# parameters for examining sentences to add to the context
# TODO: test detokenized stuff, make sure it's the same doc in the same order.
# change preceding rng condition to always true
view_preceding = True
view_radius = 1
while len(context_tokens) < padless_max_len:
# keep removing sentences while the context is too large.
# keep adding sentences while the context can accommodate more.
if view_preceding:
examine_idx = input_sentence_idx - view_radius
if examine_idx >= 0:
@@ -1001,6 +1005,25 @@
if view_radius > num_sentences:
break
# detokenized_input = self.tokenizer.DecodeIds(input_tokens)
# detokenized_context = self.tokenizer.DecodeIds(context_tokens)
# encoded_sentences = [self.tokenizer.EncodeAsIds(s).tokenization for s in doc]
# full_document_encoded = list(itertools.chain(*encoded_sentences))
# detokenized_doc = self.tokenizer.DecodeIds(full_document_encoded)
# b1 = detokenized_input in detokenized_doc
# b2 = detokenized_context in detokenized_doc
# print("-" * 100)
# print('> input idx: {}'.format(input_sentence_idx))
# print('> input in doc: {}'.format(b1))
# print('> context in doc: {}'.format(b2))
# print('> input: {}'.format(detokenized_input))
# print('> context: {}'.format(detokenized_context))
# print('\n> doc: {}'.format(detokenized_doc))
# if not (b1 and b2):
# raise ValueError("you dun goofed")
# assemble the tokens and token types of the context
context_tokens = context_tokens[:padless_max_len]
context_token_types = context_token_types[:padless_max_len]
......
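Taken together, the hunks above implement the inverse cloze pairing: one randomly chosen sentence becomes the query, its neighboring sentences become the context, and only 10% of the time does the query itself stay inside the context. A minimal, self-contained sketch of that pairing, using hypothetical names (build_ict_example, max_context_len, doc_sentences) and pre-tokenized sentences as plain lists; the exact order in which neighbors are stitched together is assumed here, not taken from the dataset code:

import random

def build_ict_example(doc_sentences, max_context_len, rng):
    # doc_sentences: list of sentences, each a list of token ids.
    num_sentences = len(doc_sentences)
    input_idx = rng.randint(0, num_sentences - 1)
    input_tokens = list(doc_sentences[input_idx])

    # 10% of the time the query sentence stays in the context;
    # otherwise the context is built only from its neighbors.
    context_tokens = list(input_tokens) if rng.random() < 0.1 else []

    # Grow the context outward, alternating between the preceding and
    # the following neighbor, until it is full or the doc is exhausted.
    view_preceding, view_radius = True, 1
    while len(context_tokens) < max_context_len:
        idx = input_idx - view_radius if view_preceding else input_idx + view_radius
        if 0 <= idx < num_sentences:
            if view_preceding:
                context_tokens = list(doc_sentences[idx]) + context_tokens
            else:
                context_tokens = context_tokens + list(doc_sentences[idx])
        if not view_preceding:
            view_radius += 1
        view_preceding = not view_preceding
        if view_radius > num_sentences:
            break

    return input_tokens, context_tokens[:max_context_len]

Each (query, context) pair produced this way is a positive example; during pretraining the other contexts in the batch serve as in-batch negatives, which is what the retrieval scores further down are computed over.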
@@ -215,9 +215,10 @@ class BertModel(MegatronModule):
state_dict_[self._language_model_key] \
= self.language_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
state_dict_[self._lm_head_key] \
= self.lm_head.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
if not self.add_ict_head:
state_dict_[self._lm_head_key] \
= self.lm_head.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
if self.add_binary_head:
state_dict_[self._binary_head_key] \
= self.binary_head.state_dict(destination, prefix, keep_vars)
@@ -232,8 +233,9 @@
self.language_model.load_state_dict(
state_dict[self._language_model_key], strict=strict)
self.lm_head.load_state_dict(
state_dict[self._lm_head_key], strict=strict)
if not self.add_ict_head:
self.lm_head.load_state_dict(
state_dict[self._lm_head_key], strict=strict)
if self.add_binary_head:
self.binary_head.load_state_dict(
state_dict[self._binary_head_key], strict=strict)
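The two hunks above stop saving and loading the LM head when the model was built with the ICT head in its place. A small plain-PyTorch illustration of why such a guard (or strict=False) is needed whenever the set of submodules differs between the writer and the reader of a checkpoint; TinyBert and its attributes are hypothetical stand-ins, not the Megatron classes:

import torch.nn as nn

class TinyBert(nn.Module):
    # Hypothetical stand-in: a backbone plus an optional LM head.
    def __init__(self, add_lm_head=True):
        super().__init__()
        self.backbone = nn.Linear(8, 8)
        self.lm_head = nn.Linear(8, 8) if add_lm_head else None

with_head = TinyBert(add_lm_head=True)
without_head = TinyBert(add_lm_head=False)

# A checkpoint written by the head-less variant has no 'lm_head.*' keys,
# so loading it into a model that expects them raises under strict=True.
ckpt = without_head.state_dict()
missing, unexpected = with_head.load_state_dict(ckpt, strict=False)
print(missing)  # ['lm_head.weight', 'lm_head.bias']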
@@ -291,8 +293,8 @@ class ICTBertModel(MegatronModule):
def forward(self, input_tokens, input_attention_mask, input_types,
context_tokens, context_attention_mask, context_types):
question_ict_logits, _ = self.question_model.forward(input_tokens, input_attention_mask, input_types)
context_ict_logits, _ = self.context_model.forward(context_tokens, context_attention_mask, context_types)
question_ict_logits, _ = self.question_model.forward(input_tokens, 1 - input_attention_mask, input_types)
context_ict_logits, _ = self.context_model.forward(context_tokens, 1 - context_attention_mask, context_types)
# [batch x h] * [h x batch]
retrieval_scores = question_ict_logits.matmul(torch.transpose(context_ict_logits, 0, 1))
......
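The matmul above yields a [batch x batch] score matrix: entry (i, j) is the dot product between query embedding i and context embedding j, so the diagonal scores each query against its own context and the off-diagonal entries act as in-batch negatives. Note also that the `1 -` inversion of the padding mask now happens inside this forward pass rather than in the training script (next file). A shape-only sketch with random embeddings standing in for the two BERT towers:

import torch
import torch.nn.functional as F

batch_size, hidden = 4, 16
question_embeds = torch.randn(batch_size, hidden)  # [batch x h] from the query tower
context_embeds = torch.randn(batch_size, hidden)   # [batch x h] from the context tower

# [batch x h] * [h x batch] -> [batch x batch]
retrieval_scores = question_embeds.matmul(torch.transpose(context_embeds, 0, 1))

# The matching context for query i is context i, so the targets are the diagonal.
targets = torch.arange(batch_size)
loss = F.cross_entropy(retrieval_scores, targets)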
@@ -93,14 +93,23 @@ def forward_step(data_iterator, model, args, timers):
timers('batch generator').stop()
# Forward model.
retrieval_scores = model(input_tokens, 1 - input_pad_mask, input_types,
context_tokens, 1 - context_pad_mask, context_types)
# TODO: important to make sure that everything, including padding mask is as expected here.
retrieval_scores = model(input_tokens, input_pad_mask, input_types,
context_tokens, context_pad_mask, context_types).float()
softmaxed = F.softmax(retrieval_scores, dim=0)
retrieval_loss = F.cross_entropy(softmaxed, torch.arange(softmaxed.shape[0]).cuda())
reduced_losses = reduce_losses([retrieval_loss])
softmaxed = F.softmax(retrieval_scores, dim=1)
top5_vals, top5_indices = torch.topk(softmaxed, k=5, sorted=True)
batch_size = softmaxed.shape[0]
return retrieval_loss, {'retrieval loss': reduced_losses[0]}
top1_acc = torch.cuda.FloatTensor([sum([int(top5_indices[i, 0] == i) for i in range(batch_size)]) / batch_size])
top5_acc = torch.cuda.FloatTensor([sum([int(i in top5_indices[i]) for i in range(batch_size)]) / batch_size])
retrieval_loss = F.cross_entropy(softmaxed, torch.arange(batch_size).cuda())
reduced_losses = reduce_losses([retrieval_loss, top1_acc, top5_acc])
return retrieval_loss, {'retrieval loss': reduced_losses[0],
'top1_acc': reduced_losses[1],
'top5_acc': reduced_losses[2]}
def get_train_val_test_data(args):
......
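For the loss and accuracy bookkeeping added in the hunk above, a standalone sketch with a random score matrix standing in for the model output; F.cross_entropy applies log_softmax internally, so this sketch feeds it the raw scores and uses the softmax probabilities only for the top-k readout:

import torch
import torch.nn.functional as F

batch_size = 8
retrieval_scores = torch.randn(batch_size, batch_size)  # query x context scores

# The matching context for query i sits in column i.
targets = torch.arange(batch_size)
retrieval_loss = F.cross_entropy(retrieval_scores, targets)

# Top-1 / top-5 retrieval accuracy over the in-batch candidates.
softmaxed = F.softmax(retrieval_scores, dim=1)
top5_vals, top5_indices = torch.topk(softmaxed, k=5, dim=1, sorted=True)
top1_acc = (top5_indices[:, 0] == targets).float().mean()
top5_acc = (top5_indices == targets.unsqueeze(1)).any(dim=1).float().mean()
print(retrieval_loss.item(), top1_acc.item(), top5_acc.item())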