"src/vscode:/vscode.git/clone" did not exist on "6156cf8f22ebc61f66ebf60d8bf415cb3015fcc3"
Commit 6e856fac authored by Neel Kant

Add while condition to InverseClozeDataset to protect against corner cases

parent 6f56b909
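The diff below wraps the body of get_input_and_context in a bounded retry loop: whenever a corner case produces an empty document, an empty input sentence, or an empty context, the loop simply tries again, and after 20 failed attempts it raises a RuntimeError instead of returning a degenerate sample. A minimal sketch of that pattern follows; the sample_once callable is a hypothetical stand-in for one sampling attempt and is not part of the actual code.

import random

def sample_with_retries(sample_once, max_tries=20):
    """Call sample_once() until it returns a non-empty result.

    The else clause of a while loop runs only when the loop condition
    becomes false, i.e. when every attempt was used up without returning.
    """
    num_tries = 0
    while num_tries < max_tries:
        num_tries += 1
        result = sample_once()
        if not result:
            # corner case (e.g. empty document or empty sentence): retry
            continue
        return result
    else:
        raise RuntimeError("Could not get a valid data point")

# usage: a sampler that occasionally comes back empty
rng = random.Random(0)
print(sample_with_retries(lambda: [1, 2, 3] if rng.random() > 0.3 else []))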
@@ -910,7 +910,7 @@ class InverseClozeDataset(data.Dataset):
         # get seq length. Save 2 tokens for beginning and end
         target_seq_length = self.max_seq_len - 2
         if rng.random() < self.short_seq_prob:
-            target_seq_length = rng.randint(2, target_seq_length)
+            target_seq_length = rng.randint(5, target_seq_length)
         input_data, context_data = self.get_input_and_context(target_seq_length, rng, np_rng)
         input_tokens, input_token_types, input_pad_mask = input_data
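For reference, this first hunk only raises the floor used when a sample is randomly shortened: a minimum of 5 tokens (instead of 2) leaves room for meaningful content once the CLS and SEP tokens are added. A small sketch of that sampling step, using made-up values for max_seq_len and short_seq_prob:

import random

rng = random.Random(1234)
max_seq_len = 64        # illustrative value, not from the repository
short_seq_prob = 0.1    # illustrative value, not from the repository

target_seq_length = max_seq_len - 2  # save 2 tokens for CLS and SEP
if rng.random() < short_seq_prob:
    # a floor of 5 (previously 2) avoids nearly empty targets
    target_seq_length = rng.randint(5, target_seq_length)
print(target_seq_length)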
@@ -950,69 +950,78 @@ class InverseClozeDataset(data.Dataset):
     def get_input_and_context(self, target_seq_length, rng, np_rng):
         """fetches a sentence and its surrounding context"""
-        doc = None
-        while doc is None:
-            doc_idx = self.get_weighted_samples(np_rng)
-            # doc is a list of sentences
-            doc = self.get_sentence_split_doc(doc_idx)
-            if not doc:
-                doc = None
-        # set up and tokenize the entire selected document
-        num_sentences = len(doc)
-        all_token_lists = []
-        all_token_type_lists = []
-        for sentence in doc:
-            tokens, token_types = self.sentence_tokenize(sentence, 0)
-            all_token_lists.append(tokens)
-            all_token_type_lists.append(token_types)
-        sentence_token_lens = [len(l) for l in all_token_lists]
-        inclusion_mask = [False] * num_sentences
-        # select a random sentence from the document as input
-        input_sentence_idx = rng.randint(0, len(all_token_lists) - 1)
-        input_tokens = all_token_lists[input_sentence_idx].copy()[:self.max_seq_len - 2]
-        input_token_types = all_token_type_lists[input_sentence_idx].copy()[:self.max_seq_len - 2]
-        # 10% of the time, the input sentence is left in the context.
-        # The other 90% of the time, remove it.
-        if rng.random() < 0.1:
-            inclusion_mask[input_sentence_idx] = True
-        # parameters for examining sentences to remove from the context
-        view_preceding = True
-        view_radius = 1
-        while sum(s for i, s in enumerate(sentence_token_lens) if inclusion_mask[i]) < self.max_seq_len - 2:
-            # keep removing sentences while the context is too large.
-            if view_preceding:
-                examine_idx = input_sentence_idx - view_radius
-                if examine_idx >= 0:
-                    inclusion_mask[examine_idx] = True
-            else:
-                examine_idx = input_sentence_idx + view_radius
-                if examine_idx < num_sentences:
-                    inclusion_mask[examine_idx] = True
-                view_radius += 1
-            view_preceding = not view_preceding
-            if view_radius > num_sentences:
-                break
-        # assemble the tokens and token types of the context
-        context_tokens = list(itertools.chain(
-            *[l for i, l in enumerate(all_token_lists) if inclusion_mask[i]]))[:self.max_seq_len - 2]
-        context_token_types = list(itertools.chain(
-            *[l for i, l in enumerate(all_token_type_lists) if inclusion_mask[i]]))[:self.max_seq_len - 2]
-        # concatenate 'CLS' and 'SEP' tokens and add extra token types
-        input_tokens, input_token_types, input_pad_mask = self.concat_and_pad_tokens(
-            input_tokens, input_token_types)
-        context_tokens, context_token_types, context_pad_mask = self.concat_and_pad_tokens(
-            context_tokens, context_token_types)
-        return (input_tokens, input_token_types, input_pad_mask), \
-               (context_tokens, context_token_types, context_pad_mask)
+        num_tries = 0
+        while num_tries < 20:
+            num_tries += 1
+            doc = None
+            while doc is None:
+                doc_idx = self.get_weighted_samples(np_rng)
+                # doc is a list of sentences
+                doc = self.get_sentence_split_doc(doc_idx)
+                if not doc:
+                    doc = None
+            # set up and tokenize the entire selected document
+            num_sentences = len(doc)
+            all_token_lists = []
+            all_token_type_lists = []
+            for sentence in doc:
+                tokens, token_types = self.sentence_tokenize(sentence, 0)
+                all_token_lists.append(tokens)
+                all_token_type_lists.append(token_types)
+            sentence_token_lens = [len(l) for l in all_token_lists]
+            inclusion_mask = [False] * num_sentences
+            padless_max_len = self.max_seq_len - 2
+            # select a random sentence from the document as input
+            input_sentence_idx = rng.randint(0, len(all_token_lists) - 1)
+            input_tokens = all_token_lists[input_sentence_idx].copy()[:target_seq_length]
+            input_token_types = all_token_type_lists[input_sentence_idx].copy()[:target_seq_length]
+            if not len(input_tokens) > 0:
+                continue
+            # 10% of the time, the input sentence is left in the context.
+            # The other 90% of the time, remove it.
+            if rng.random() < 0.1:
+                inclusion_mask[input_sentence_idx] = True
+            # parameters for examining sentences to remove from the context
+            view_preceding = True
+            view_radius = 1
+            while sum(s for i, s in enumerate(sentence_token_lens) if inclusion_mask[i]) < padless_max_len:
+                # keep removing sentences while the context is too large.
+                if view_preceding:
+                    examine_idx = input_sentence_idx - view_radius
+                    if examine_idx >= 0:
+                        inclusion_mask[examine_idx] = True
+                else:
+                    examine_idx = input_sentence_idx + view_radius
+                    if examine_idx < num_sentences:
+                        inclusion_mask[examine_idx] = True
+                    view_radius += 1
+                view_preceding = not view_preceding
+                if view_radius > num_sentences:
+                    break
+            # assemble the tokens and token types of the context
+            context_tokens = list(itertools.chain(
+                *[l for i, l in enumerate(all_token_lists) if inclusion_mask[i]]))[:padless_max_len]
+            context_token_types = list(itertools.chain(
+                *[l for i, l in enumerate(all_token_type_lists) if inclusion_mask[i]]))[:padless_max_len]
+            if not len(context_tokens) > 0:
+                continue
+            # concatenate 'CLS' and 'SEP' tokens and add extra token types
+            input_tokens, input_token_types, input_pad_mask = self.concat_and_pad_tokens(
+                input_tokens, input_token_types)
+            context_tokens, context_token_types, context_pad_mask = self.concat_and_pad_tokens(
+                context_tokens, context_token_types)
+            return (input_tokens, input_token_types, input_pad_mask), \
+                   (context_tokens, context_token_types, context_pad_mask)
+        else:
+            raise RuntimeError("Could not get a valid data point from InverseClozeDataset")
     def concat_and_pad_tokens(self, tokens, token_types):
         """concat with special tokens and pad sequence to self.max_seq_len"""
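The context-selection loop in the new version grows a window of sentences around the chosen input sentence, alternating between the preceding and the following side until the token budget is filled or the whole document has been examined. The same idea as a standalone sketch (function name and example inputs are illustrative only, not from the repository):

def select_context_window(sentence_token_lens, center_idx, budget, keep_center=False):
    """Mark sentences around center_idx until their token count reaches
    the budget, alternating between preceding and following sentences."""
    num_sentences = len(sentence_token_lens)
    inclusion_mask = [False] * num_sentences
    if keep_center:
        inclusion_mask[center_idx] = True
    view_preceding = True
    view_radius = 1
    while sum(s for i, s in enumerate(sentence_token_lens) if inclusion_mask[i]) < budget:
        if view_preceding:
            examine_idx = center_idx - view_radius
            if examine_idx >= 0:
                inclusion_mask[examine_idx] = True
        else:
            examine_idx = center_idx + view_radius
            if examine_idx < num_sentences:
                inclusion_mask[examine_idx] = True
            view_radius += 1
        view_preceding = not view_preceding
        if view_radius > num_sentences:
            break
    return inclusion_mask

# sentence lengths in tokens, input sentence at index 3, budget of 30 tokens
print(select_context_window([12, 8, 15, 10, 9, 20, 7], 3, 30))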
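concat_and_pad_tokens itself is unchanged and only appears as trailing context; its body is not shown in this diff. Purely to illustrate what "concat with special tokens and pad sequence to self.max_seq_len" usually entails, a hypothetical standalone version could look like the following; the token ids, the pad value, and the exact mask convention are assumptions, not taken from the repository.

def concat_and_pad_tokens(tokens, token_types, max_seq_len,
                          cls_id=101, sep_id=102, pad_id=0):
    """Hypothetical sketch: wrap tokens in CLS/SEP, then pad to max_seq_len.

    Returns padded tokens, matching token types, and a pad mask that is
    1 on real tokens and 0 on padding."""
    tokens = [cls_id] + list(tokens) + [sep_id]
    token_types = [token_types[0] if token_types else 0] + list(token_types) \
                  + [token_types[-1] if token_types else 0]
    pad_mask = [1] * len(tokens)
    num_pad = max_seq_len - len(tokens)
    tokens += [pad_id] * num_pad
    token_types += [0] * num_pad
    pad_mask += [0] * num_pad
    return tokens, token_types, pad_mask

# a 4-token sentence padded out to a maximum length of 10
print(concat_and_pad_tokens([5, 6, 7, 8], [0, 0, 0, 0], 10))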