Commit e949a5c5 authored by Neel Kant

Make InverseClozeDataset more efficient

parent 6e856fac
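
This commit replaces eager per-document tokenization with lazy, incremental context construction: instead of tokenizing every sentence of the document up front and tracking an inclusion mask, the new code tokenizes only the input sentence plus the neighboring sentences that are actually pulled into the context, alternately prepending a preceding sentence and appending a following one until the context reaches the maximum length.

Below is a minimal, self-contained sketch of that sampling strategy, not the dataset's actual code: sample_ict_pair and tokenize are hypothetical names, a stdlib rng.randrange stands in for the dataset's rng.randint, token types and special-token padding are omitted, the input is truncated to padless_max_len rather than a separate target_seq_length, and the real code re-samples (via continue) when the input or context comes out empty.

import random
from typing import Callable, List, Tuple

def sample_ict_pair(doc: List[str],
                    tokenize: Callable[[str], List[int]],
                    max_seq_len: int,
                    rng: random.Random) -> Tuple[List[int], List[int]]:
    """Sample an (input sentence, surrounding context) pair from one document.

    Only sentences that actually land in the context get tokenized, instead
    of tokenizing the whole document up front as the old code did.
    """
    padless_max_len = max_seq_len - 2          # leave room for [CLS]/[SEP]
    num_sentences = len(doc)
    input_idx = rng.randrange(num_sentences)   # random input sentence
    input_tokens = tokenize(doc[input_idx])[:padless_max_len]

    context_tokens: List[int] = []
    if rng.random() < 0.1:                     # 10%: keep input in its context
        context_tokens = input_tokens.copy()

    view_preceding, view_radius = True, 1
    while len(context_tokens) < padless_max_len:
        if view_preceding:
            idx = input_idx - view_radius
            if idx >= 0:
                context_tokens = tokenize(doc[idx]) + context_tokens
        else:
            idx = input_idx + view_radius
            if idx < num_sentences:
                context_tokens += tokenize(doc[idx])
            view_radius += 1                   # widen after trying both sides
        view_preceding = not view_preceding
        if view_radius > num_sentences:
            break                              # whole document is still short
    return input_tokens, context_tokens[:padless_max_len]

For example, with a toy stand-in tokenizer:

doc = ["First sentence here.", "Second one.", "Third one.", "Fourth."]
toy_tokenize = lambda s: [hash(w) % 30000 for w in s.split()]
inp, ctx = sample_ict_pair(doc, toy_tokenize, max_seq_len=16, rng=random.Random(0))

Because each sentence is tokenized at most once and only when needed, the per-sample cost depends on the context window rather than on the document length.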
@@ -963,52 +963,47 @@ class InverseClozeDataset(data.Dataset):
             # set up the selected document
             num_sentences = len(doc)
-            all_token_lists = []
-            all_token_type_lists = []
-            for sentence in doc:
-                tokens, token_types = self.sentence_tokenize(sentence, 0)
-                all_token_lists.append(tokens)
-                all_token_type_lists.append(token_types)
-            sentence_token_lens = [len(l) for l in all_token_lists]
-            inclusion_mask = [False] * num_sentences
             padless_max_len = self.max_seq_len - 2
             # select a random sentence from the document as input
-            input_sentence_idx = rng.randint(0, len(all_token_lists) - 1)
-            input_tokens = all_token_lists[input_sentence_idx].copy()[:target_seq_length]
-            input_token_types = all_token_type_lists[input_sentence_idx].copy()[:target_seq_length]
+            input_sentence_idx = rng.randint(num_sentences)
+            tokens, token_types = self.sentence_tokenize(doc[input_sentence_idx], 0)
+            input_tokens, input_token_types = tokens[:target_seq_length], token_types[:target_seq_length]
             if not len(input_tokens) > 0:
                 continue
+            context_tokens, context_token_types = [], []
             # 10% of the time, the input sentence is left in the context.
             # The other 90% of the time, remove it.
             if rng.random() < 0.1:
-                inclusion_mask[input_sentence_idx] = True
+                context_tokens = input_tokens.copy()
+                context_token_types = input_token_types.copy()
             # parameters for examining sentences to add to the context
             view_preceding = True
             view_radius = 1
-            while sum(s for i, s in enumerate(sentence_token_lens) if inclusion_mask[i]) < padless_max_len:
+            while len(context_tokens) < padless_max_len:
                 # keep adding sentences while the context is too small.
                 if view_preceding:
                     examine_idx = input_sentence_idx - view_radius
                     if examine_idx >= 0:
-                        inclusion_mask[examine_idx] = True
+                        new_tokens, new_token_types = self.sentence_tokenize(doc[examine_idx], 0)
+                        context_tokens = new_tokens + context_tokens
+                        context_token_types = new_token_types + context_token_types
                 else:
                     examine_idx = input_sentence_idx + view_radius
                     if examine_idx < num_sentences:
-                        inclusion_mask[examine_idx] = True
+                        new_tokens, new_token_types = self.sentence_tokenize(doc[examine_idx], 0)
+                        context_tokens += new_tokens
+                        context_token_types += new_token_types
                     view_radius += 1
                 view_preceding = not view_preceding
                 if view_radius > num_sentences:
                     break
             # assemble the tokens and token types of the context
-            context_tokens = list(itertools.chain(
-                *[l for i, l in enumerate(all_token_lists) if inclusion_mask[i]]))[:padless_max_len]
-            context_token_types = list(itertools.chain(
-                *[l for i, l in enumerate(all_token_type_lists) if inclusion_mask[i]]))[:padless_max_len]
+            context_tokens = context_tokens[:padless_max_len]
+            context_token_types = context_token_types[:padless_max_len]
             if not len(context_tokens) > 0:
                 continue