Commit 6e856fac authored by Neel Kant

Add while condition to InverseClozeDataset to protect against corner cases

parent 6f56b909
@@ -910,7 +910,7 @@ class InverseClozeDataset(data.Dataset):
         # get seq length. Save 2 tokens for beginning and end
         target_seq_length = self.max_seq_len - 2
         if rng.random() < self.short_seq_prob:
-            target_seq_length = rng.randint(2, target_seq_length)
+            target_seq_length = rng.randint(5, target_seq_length)
         input_data, context_data = self.get_input_and_context(target_seq_length, rng, np_rng)
         input_tokens, input_token_types, input_pad_mask = input_data
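
For reference, a minimal sketch of how this short-sequence branch behaves, using Python's random.Random as the surrounding code does (the max_seq_len and short_seq_prob values below are made up for illustration). randint is inclusive on both ends, so the old floor of 2 allowed a 2-token target, while the new floor of 5 keeps the truncated input sentence from becoming nearly empty.

import random

max_seq_len = 128          # illustrative value, not taken from the dataset config
short_seq_prob = 0.1       # illustrative value
rng = random.Random(1234)

# save 2 tokens for the beginning and end special tokens
target_seq_length = max_seq_len - 2
if rng.random() < short_seq_prob:
    # inclusive bounds: the result can be as small as the lower bound itself
    target_seq_length = rng.randint(5, target_seq_length)
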
@@ -950,69 +950,78 @@ class InverseClozeDataset(data.Dataset):
     def get_input_and_context(self, target_seq_length, rng, np_rng):
         """fetches a sentence and its surrounding context"""
-        doc = None
-        while doc is None:
-            doc_idx = self.get_weighted_samples(np_rng)
-            # doc is a list of sentences
-            doc = self.get_sentence_split_doc(doc_idx)
-            if not doc:
-                doc = None
-
-        # set up and tokenize the entire selected document
-        num_sentences = len(doc)
-        all_token_lists = []
-        all_token_type_lists = []
-        for sentence in doc:
-            tokens, token_types = self.sentence_tokenize(sentence, 0)
-            all_token_lists.append(tokens)
-            all_token_type_lists.append(token_types)
-
-        sentence_token_lens = [len(l) for l in all_token_lists]
-        inclusion_mask = [False] * num_sentences
-        # select a random sentence from the document as input
-        input_sentence_idx = rng.randint(0, len(all_token_lists) - 1)
-        input_tokens = all_token_lists[input_sentence_idx].copy()[:self.max_seq_len - 2]
-        input_token_types = all_token_type_lists[input_sentence_idx].copy()[:self.max_seq_len - 2]
-
-        # 10% of the time, the input sentence is left in the context.
-        # The other 90% of the time, remove it.
-        if rng.random() < 0.1:
-            inclusion_mask[input_sentence_idx] = True
-
-        # parameters for examining sentences to remove from the context
-        view_preceding = True
-        view_radius = 1
-        while sum(s for i, s in enumerate(sentence_token_lens) if inclusion_mask[i]) < self.max_seq_len - 2:
-            # keep removing sentences while the context is too large.
-            if view_preceding:
-                examine_idx = input_sentence_idx - view_radius
-                if examine_idx >= 0:
-                    inclusion_mask[examine_idx] = True
-            else:
-                examine_idx = input_sentence_idx + view_radius
-                if examine_idx < num_sentences:
-                    inclusion_mask[examine_idx] = True
-                view_radius += 1
-            view_preceding = not view_preceding
-            if view_radius > num_sentences:
-                break
-
-        # assemble the tokens and token types of the context
-        context_tokens = list(itertools.chain(
-            *[l for i, l in enumerate(all_token_lists) if inclusion_mask[i]]))[:self.max_seq_len - 2]
-        context_token_types = list(itertools.chain(
-            *[l for i, l in enumerate(all_token_type_lists) if inclusion_mask[i]]))[:self.max_seq_len - 2]
-
-        # concatenate 'CLS' and 'SEP' tokens and add extra token types
-        input_tokens, input_token_types, input_pad_mask = self.concat_and_pad_tokens(
-            input_tokens, input_token_types)
-        context_tokens, context_token_types, context_pad_mask = self.concat_and_pad_tokens(
-            context_tokens, context_token_types)
-
-        return (input_tokens, input_token_types, input_pad_mask), \
-               (context_tokens, context_token_types, context_pad_mask)
+        num_tries = 0
+        while num_tries < 20:
+            num_tries += 1
+            doc = None
+            while doc is None:
+                doc_idx = self.get_weighted_samples(np_rng)
+                # doc is a list of sentences
+                doc = self.get_sentence_split_doc(doc_idx)
+                if not doc:
+                    doc = None
+
+            # set up and tokenize the entire selected document
+            num_sentences = len(doc)
+            all_token_lists = []
+            all_token_type_lists = []
+            for sentence in doc:
+                tokens, token_types = self.sentence_tokenize(sentence, 0)
+                all_token_lists.append(tokens)
+                all_token_type_lists.append(token_types)
+
+            sentence_token_lens = [len(l) for l in all_token_lists]
+            inclusion_mask = [False] * num_sentences
+            padless_max_len = self.max_seq_len - 2
+
+            # select a random sentence from the document as input
+            input_sentence_idx = rng.randint(0, len(all_token_lists) - 1)
+            input_tokens = all_token_lists[input_sentence_idx].copy()[:target_seq_length]
+            input_token_types = all_token_type_lists[input_sentence_idx].copy()[:target_seq_length]
+            if not len(input_tokens) > 0:
+                continue
+
+            # 10% of the time, the input sentence is left in the context.
+            # The other 90% of the time, remove it.
+            if rng.random() < 0.1:
+                inclusion_mask[input_sentence_idx] = True
+
+            # parameters for examining sentences to remove from the context
+            view_preceding = True
+            view_radius = 1
+            while sum(s for i, s in enumerate(sentence_token_lens) if inclusion_mask[i]) < padless_max_len:
+                # keep removing sentences while the context is too large.
+                if view_preceding:
+                    examine_idx = input_sentence_idx - view_radius
+                    if examine_idx >= 0:
+                        inclusion_mask[examine_idx] = True
+                else:
+                    examine_idx = input_sentence_idx + view_radius
+                    if examine_idx < num_sentences:
+                        inclusion_mask[examine_idx] = True
+                    view_radius += 1
+                view_preceding = not view_preceding
+                if view_radius > num_sentences:
+                    break
+
+            # assemble the tokens and token types of the context
+            context_tokens = list(itertools.chain(
+                *[l for i, l in enumerate(all_token_lists) if inclusion_mask[i]]))[:padless_max_len]
+            context_token_types = list(itertools.chain(
+                *[l for i, l in enumerate(all_token_type_lists) if inclusion_mask[i]]))[:padless_max_len]
+            if not len(context_tokens) > 0:
+                continue
+
+            # concatenate 'CLS' and 'SEP' tokens and add extra token types
+            input_tokens, input_token_types, input_pad_mask = self.concat_and_pad_tokens(
+                input_tokens, input_token_types)
+            context_tokens, context_token_types, context_pad_mask = self.concat_and_pad_tokens(
+                context_tokens, context_token_types)
+
+            return (input_tokens, input_token_types, input_pad_mask), \
+                   (context_tokens, context_token_types, context_pad_mask)
+        else:
+            raise RuntimeError("Could not get a valid data point from InverseClozeDataset")
 
     def concat_and_pad_tokens(self, tokens, token_types):
         """concat with special tokens and pad sequence to self.max_seq_len"""
...
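
The structural change above is a bounded-retry guard: resample until a usable (input, context) pair is produced, and raise instead of looping forever if every attempt hits a corner case (an empty input sentence or an empty context after truncation). A minimal, self-contained sketch of the same while/else control flow, with a hypothetical sample_example() standing in for the document sampling and tokenization done by the dataset:

import random

def sample_example(rng):
    """Hypothetical stand-in for the document sampling and truncation above.
    Returns None when a corner case (e.g. an empty sentence) is hit."""
    return None if rng.random() < 0.3 else "a usable (input, context) pair"

def get_example_with_retries(rng, max_tries=20):
    num_tries = 0
    while num_tries < max_tries:
        num_tries += 1
        example = sample_example(rng)
        if example is None:
            continue  # corner case hit: resample, like the `continue` statements in the diff
        return example
    else:
        # reached only when the loop condition fails after max_tries attempts,
        # i.e. no iteration returned a usable example
        raise RuntimeError("Could not get a valid data point")

print(get_example_with_retries(random.Random(0)))

Because a successful attempt returns from inside the loop, the else clause runs only when all attempts complete without producing an example, which is exactly when the RuntimeError should fire.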