"tests/python/pytorch/dataloading/test_dataloader.py" did not exist on "f5eb80d221fec8690e8cfb087256671545bb9a5a"
Commit d2eabecb authored by Neel Kant

Complete __getitem__ for InverseClozeDataset

parent 1c4e8955
@@ -791,7 +791,7 @@ class bert_sentencepair_dataset(data.Dataset):
     def mask_token(self, idx, tokens, types, vocab_words, rng):
         """
         helper function to mask `idx` token from `tokens` according to
-        section 3.3.1 of https://arxiv.org/pdf/1810.04805.pdf
+        section 3.1.1 of https://arxiv.org/pdf/1810.04805.pdf
         """
         label = tokens[idx]
         if rng.random() < 0.8:
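The corrected reference points at the masked-LM rule in the BERT paper: of the tokens chosen for prediction, roughly 80% are replaced with [MASK], 10% are left unchanged, and 10% are swapped for a random vocabulary word. A minimal standalone sketch of that rule, assuming only a mask id and a vocabulary list rather than the repository's tokenizer API:

    import random

    def mask_token_sketch(idx, tokens, vocab_words, mask_id, rng=random):
        """Apply the BERT-style 80/10/10 replacement to tokens[idx]; return the original token."""
        label = tokens[idx]
        r = rng.random()
        if r < 0.8:
            tokens[idx] = mask_id                    # 80%: replace with [MASK]
        elif r < 0.9:
            pass                                     # 10%: keep the original token
        else:
            tokens[idx] = rng.choice(vocab_words)    # 10%: replace with a random vocabulary token
        return label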
@@ -856,15 +856,12 @@ class InverseClozeDataset(data.Dataset):
     Arguments:
         ds (Dataset or array-like): data corpus to use for training
         max_seq_len (int): maximum sequence length to use for a target sentence
-        mask_lm_prob (float): proportion of tokens to mask for masked LM
-        max_preds_per_seq (int): Maximum number of masked tokens per sentence pair. Default: math.ceil(max_seq_len*mask_lm_prob/10)*10
         short_seq_prob (float): Proportion of sentence pairs purposefully shorter than max_seq_len
         dataset_size (int): number of random sentencepairs in the dataset. Default: len(ds)*(len(ds)-1)
     """
     def __init__(self,
                  ds,
                  max_seq_len=512,
-                 mask_lm_prob=.15,
                  max_preds_per_seq=None,
                  short_seq_prob=.01,
                  dataset_size=None,
@@ -877,10 +874,6 @@ class InverseClozeDataset(data.Dataset):
         self.vocab_words = list(self.tokenizer.text_token_vocab.values())
         self.ds.SetTokenizer(None)
         self.max_seq_len = max_seq_len
-        self.mask_lm_prob = mask_lm_prob
-        if max_preds_per_seq is None:
-            max_preds_per_seq = math.ceil(max_seq_len*mask_lm_prob /10)*10
-        self.max_preds_per_seq = max_preds_per_seq
         self.short_seq_prob = short_seq_prob
         self.dataset_size = dataset_size
         if self.dataset_size is None:
@@ -889,9 +882,6 @@ class InverseClozeDataset(data.Dataset):
         if not self.presplit_sentences:
             nltk.download('punkt', download_dir="./nltk")
         self.weighted = weighted
-        self.get_weighting()
-
-    def get_weighting(self):
         if self.weighted:
             if hasattr(self.ds, 'is_lazy') and self.ds.is_lazy:
                 lens = np.array(self.ds.lens)
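With the masked-LM arguments removed and the weighting setup inlined into __init__, constructing the dataset only needs the corpus plus sequence-length and sampling options. A hypothetical instantiation, with illustrative argument values (not taken from the repo's training scripts):

    # `corpus` stands in for an indexable dataset already wrapped with the project's tokenizer (assumption)
    dataset = InverseClozeDataset(ds=corpus,
                                  max_seq_len=512,
                                  short_seq_prob=0.01,
                                  weighted=True)
    print(len(dataset))   # defaults to len(corpus) * (len(corpus) - 1) random pairs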
@@ -907,7 +897,7 @@ class InverseClozeDataset(data.Dataset):
             idx = np_rng.randint(self.total_len)
             return bisect_right(self.weighting, idx)
         else:
-            return np_rng.randint(self.ds_len)
+            return np_rng.randint(self.ds_len - 1)

     def __len__(self):
         return self.dataset_size
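get_weighted_samples draws document indices in proportion to document length: a position is sampled uniformly over the total corpus length and mapped back to the document that contains it with bisect_right over cumulative lengths. Only the lookup is visible in this hunk, so treat the cumulative-sum construction below as an assumption; a self-contained sketch with made-up lengths:

    from bisect import bisect_right
    import numpy as np

    lens = np.array([10, 200, 5, 50])          # per-document lengths (illustrative)
    weighting = list(np.cumsum(lens))          # cumulative lengths: [10, 210, 215, 265]
    total_len = int(lens.sum())

    np_rng = np.random.RandomState(seed=0)
    idx = np_rng.randint(total_len)            # uniform position over all tokens
    doc_idx = bisect_right(weighting, idx)     # index of the document covering that position
    # a 200-token document is picked ~20x more often than a 10-token one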
@@ -917,15 +907,24 @@ class InverseClozeDataset(data.Dataset):
         rng = random.Random(idx)
         np_rng = np.random.RandomState(seed=[rng.randint(0, 2**32-1) for _ in range(16)])

-        # get seq length
-        target_seq_length = self.max_seq_len
+        # get seq length. Save 2 tokens for beginning and end
+        target_seq_length = self.max_seq_len - 2
         if rng.random() < self.short_seq_prob:
             target_seq_length = rng.randint(2, target_seq_length)

-        input_data, context_data, doc_idx = self.get_input_and_context(target_seq_length, rng, np_rng)
-        # get other documents too
-
-        # return sample
+        input_data, context_data = self.get_input_and_context(target_seq_length, rng, np_rng)
+        input_tokens, input_token_types, input_pad_mask = input_data
+        context_tokens, context_token_types, context_pad_mask = context_data
+
+        sample = {
+            'input_text': np.array(input_tokens),
+            'input_types': np.array(input_token_types),
+            'input_pad_mask': np.array(input_pad_mask),
+            'context_text': np.array(context_tokens),
+            'context_types': np.array(context_token_types),
+            'context_pad_mask': np.array(context_pad_mask)
+        }
+        return sample

     def get_sentence_split_doc(self, idx):
         """fetch document at index idx and split into sentences"""
@@ -950,17 +949,15 @@ class InverseClozeDataset(data.Dataset):
     def get_input_and_context(self, target_seq_length, rng, np_rng):
         """fetches a sentence and its surrounding context"""
-        doc = doc_idx = None
+        doc = None
         while doc is None:
-            if self.weighted:
-                doc_idx = self.get_weighted_samples(np_rng)
-            else:
-                doc_idx = rng.randint(0, self.ds_len - 1)
+            doc_idx = self.get_weighted_samples(np_rng)
             # doc is a list of sentences
             doc = self.get_sentence_split_doc(doc_idx)
             if not doc:
                 doc = None

+        # set up and tokenize the entire selected document
         num_sentences = len(doc)
         all_token_lists = []
         all_token_type_lists = []
@@ -972,9 +969,10 @@ class InverseClozeDataset(data.Dataset):
         sentence_token_lens = [len(l) for l in all_token_lists]
         inclusion_mask = [True] * num_sentences

+        # select a random sentence from the document as input
         input_sentence_idx = rng.randint(0, len(all_token_lists) - 1)
-        input_sentence_tokens = all_token_lists[input_sentence_idx].copy()
-        input_sentence_token_types = all_token_type_lists[input_sentence_idx].copy()
+        input_tokens = all_token_lists[input_sentence_idx].copy()
+        input_token_types = all_token_type_lists[input_sentence_idx].copy()

         # 10% of the time, the input sentence is left in the context.
         # The other 90% of the time, remove it.
@@ -994,42 +992,27 @@ class InverseClozeDataset(data.Dataset):
                 inclusion_mask[num_sentences - view_radius] = False
             remove_preceding = not remove_preceding

+        # assemble the tokens and token types of the context
         context_tokens = list(itertools.chain(
             *[l for i, l in enumerate(all_token_lists) if inclusion_mask[i]]))
         context_token_types = list(itertools.chain(
             *[l for i, l in enumerate(all_token_type_lists) if inclusion_mask[i]]))

-        return (input_sentence_tokens, input_sentence_token_types), (context_tokens, context_token_types), doc_idx
-
-    def calc_seq_len(self, max_seq_len):
-        return max_seq_len - 3
-
-    def mask_token(self, idx, tokens, types, vocab_words, rng):
-        """
-        helper function to mask `idx` token from `tokens` according to
-        section 3.3.1 of https://arxiv.org/pdf/1810.04805.pdf
-        """
-        label = tokens[idx]
-        if rng.random() < 0.8:
-            new_label = self.tokenizer.get_command('MASK').Id
-        else:
-            if rng.random() < 0.5:
-                new_label = label
-            else:
-                new_label = rng.choice(vocab_words)
-        tokens[idx] = new_label
-        return label
-
-    def pad_seq(self, seq):
-        """helper function to pad sequence pair"""
-        num_pad = max(0, self.max_seq_len - len(seq))
-        pad_mask = [0] * len(seq) + [1] * num_pad
-        seq += [self.tokenizer.get_command('pad').Id] * num_pad
-        return seq, pad_mask
-
-    def concat_tokens(self, tokens_a, token_types_a, tokens_b, token_types_b):
-        tokens = [self.tokenizer.get_command('ENC').Id] + tokens_a + [self.tokenizer.get_command('sep').Id] + tokens_b + [self.tokenizer.get_command('sep').Id]
-        token_types = [token_types_a[0]] + token_types_a + [token_types_a[0]] + token_types_b + [token_types_b[0]]
-        return tokens, token_types
+        # concatenate 'CLS' and 'SEP' tokens and add extra token types
+        input_tokens, input_token_types, input_pad_mask = self.concat_and_pad_tokens(
+            input_tokens, input_token_types)
+        context_tokens, context_token_types, context_pad_mask = self.concat_and_pad_tokens(
+            context_tokens, context_token_types)
+
+        return (input_tokens, input_token_types, input_pad_mask), \
+               (context_tokens, context_token_types, context_pad_mask)
+
+    def concat_and_pad_tokens(self, tokens, token_types):
+        """concat with special tokens and pad sequence to self.max_seq_len"""
+        tokens = [self.tokenizer.get_command('ENC').Id] + tokens + [self.tokenizer.get_command('sep').Id]
+        token_types = [token_types[0]] + token_types + [token_types[0]]
+
+        num_pad = max(0, self.max_seq_len - len(tokens))
+        pad_mask = [0] * len(tokens) + [1] * num_pad
+        tokens += [self.tokenizer.get_command('pad').Id] * num_pad
+        return tokens, token_types, pad_mask
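concat_and_pad_tokens wraps a token list with the 'ENC' (i.e. [CLS]) and 'sep' command ids, mirrors the first token type onto both ends, then right-pads the tokens to self.max_seq_len and records which positions are padding. A toy walk-through with made-up ids (5 for CLS, 6 for SEP, 0 for PAD) rather than the repository's tokenizer:

    max_seq_len = 8
    CLS, SEP, PAD = 5, 6, 0                        # illustrative ids only

    tokens = [11, 12, 13]
    token_types = [1, 1, 1]

    tokens = [CLS] + tokens + [SEP]                # [5, 11, 12, 13, 6]
    token_types = [token_types[0]] + token_types + [token_types[0]]

    num_pad = max(0, max_seq_len - len(tokens))    # 3
    pad_mask = [0] * len(tokens) + [1] * num_pad   # [0, 0, 0, 0, 0, 1, 1, 1]
    tokens += [PAD] * num_pad                      # [5, 11, 12, 13, 6, 0, 0, 0]

The two positions reserved for CLS and SEP are why __getitem__ now targets max_seq_len - 2; note that, as in the method above, only tokens and pad_mask reach max_seq_len while token_types stays at its unpadded length.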