Commit 6e856fac authored by Neel Kant's avatar Neel Kant
Browse files

Add while condition to InverseClozeDataset to protect against corner cases

parent 6f56b909
...@@ -910,7 +910,7 @@ class InverseClozeDataset(data.Dataset): ...@@ -910,7 +910,7 @@ class InverseClozeDataset(data.Dataset):
# get seq length. Save 2 tokens for beginning and end # get seq length. Save 2 tokens for beginning and end
target_seq_length = self.max_seq_len - 2 target_seq_length = self.max_seq_len - 2
if rng.random() < self.short_seq_prob: if rng.random() < self.short_seq_prob:
target_seq_length = rng.randint(2, target_seq_length) target_seq_length = rng.randint(5, target_seq_length)
input_data, context_data = self.get_input_and_context(target_seq_length, rng, np_rng) input_data, context_data = self.get_input_and_context(target_seq_length, rng, np_rng)
input_tokens, input_token_types, input_pad_mask = input_data input_tokens, input_token_types, input_pad_mask = input_data
...@@ -950,6 +950,9 @@ class InverseClozeDataset(data.Dataset): ...@@ -950,6 +950,9 @@ class InverseClozeDataset(data.Dataset):
def get_input_and_context(self, target_seq_length, rng, np_rng): def get_input_and_context(self, target_seq_length, rng, np_rng):
"""fetches a sentence and its surrounding context""" """fetches a sentence and its surrounding context"""
num_tries = 0
while num_tries < 20:
num_tries += 1
doc = None doc = None
while doc is None: while doc is None:
doc_idx = self.get_weighted_samples(np_rng) doc_idx = self.get_weighted_samples(np_rng)
...@@ -969,11 +972,14 @@ class InverseClozeDataset(data.Dataset): ...@@ -969,11 +972,14 @@ class InverseClozeDataset(data.Dataset):
sentence_token_lens = [len(l) for l in all_token_lists] sentence_token_lens = [len(l) for l in all_token_lists]
inclusion_mask = [False] * num_sentences inclusion_mask = [False] * num_sentences
padless_max_len = self.max_seq_len - 2
# select a random sentence from the document as input # select a random sentence from the document as input
input_sentence_idx = rng.randint(0, len(all_token_lists) - 1) input_sentence_idx = rng.randint(0, len(all_token_lists) - 1)
input_tokens = all_token_lists[input_sentence_idx].copy()[:self.max_seq_len - 2] input_tokens = all_token_lists[input_sentence_idx].copy()[:target_seq_length]
input_token_types = all_token_type_lists[input_sentence_idx].copy()[:self.max_seq_len - 2] input_token_types = all_token_type_lists[input_sentence_idx].copy()[:target_seq_length]
if not len(input_tokens) > 0:
continue
# 10% of the time, the input sentence is left in the context. # 10% of the time, the input sentence is left in the context.
# The other 90% of the time, remove it. # The other 90% of the time, remove it.
...@@ -983,7 +989,7 @@ class InverseClozeDataset(data.Dataset): ...@@ -983,7 +989,7 @@ class InverseClozeDataset(data.Dataset):
# parameters for examining sentences to remove from the context # parameters for examining sentences to remove from the context
view_preceding = True view_preceding = True
view_radius = 1 view_radius = 1
while sum(s for i, s in enumerate(sentence_token_lens) if inclusion_mask[i]) < self.max_seq_len - 2: while sum(s for i, s in enumerate(sentence_token_lens) if inclusion_mask[i]) < padless_max_len:
# keep removing sentences while the context is too large. # keep removing sentences while the context is too large.
if view_preceding: if view_preceding:
examine_idx = input_sentence_idx - view_radius examine_idx = input_sentence_idx - view_radius
...@@ -1000,9 +1006,11 @@ class InverseClozeDataset(data.Dataset): ...@@ -1000,9 +1006,11 @@ class InverseClozeDataset(data.Dataset):
# assemble the tokens and token types of the context # assemble the tokens and token types of the context
context_tokens = list(itertools.chain( context_tokens = list(itertools.chain(
*[l for i, l in enumerate(all_token_lists) if inclusion_mask[i]]))[:self.max_seq_len - 2] *[l for i, l in enumerate(all_token_lists) if inclusion_mask[i]]))[:padless_max_len]
context_token_types = list(itertools.chain( context_token_types = list(itertools.chain(
*[l for i, l in enumerate(all_token_type_lists) if inclusion_mask[i]]))[:self.max_seq_len - 2] *[l for i, l in enumerate(all_token_type_lists) if inclusion_mask[i]]))[:padless_max_len]
if not len(context_tokens) > 0:
continue
# concatenate 'CLS' and 'SEP' tokens and add extra token types # concatenate 'CLS' and 'SEP' tokens and add extra token types
input_tokens, input_token_types, input_pad_mask = self.concat_and_pad_tokens( input_tokens, input_token_types, input_pad_mask = self.concat_and_pad_tokens(
...@@ -1010,9 +1018,10 @@ class InverseClozeDataset(data.Dataset): ...@@ -1010,9 +1018,10 @@ class InverseClozeDataset(data.Dataset):
context_tokens, context_token_types, context_pad_mask = self.concat_and_pad_tokens( context_tokens, context_token_types, context_pad_mask = self.concat_and_pad_tokens(
context_tokens, context_token_types) context_tokens, context_token_types)
return (input_tokens, input_token_types, input_pad_mask), \ return (input_tokens, input_token_types, input_pad_mask), \
(context_tokens, context_token_types, context_pad_mask) (context_tokens, context_token_types, context_pad_mask)
else:
raise RuntimeError("Could not get a valid data point from InverseClozeDataset")
def concat_and_pad_tokens(self, tokens, token_types): def concat_and_pad_tokens(self, tokens, token_types):
"""concat with special tokens and pad sequence to self.max_seq_len""" """concat with special tokens and pad sequence to self.max_seq_len"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment