Commit 1c4e8955 authored by Neel Kant

Create class InverseClozeTask from bert_sentencepair_dataset and write get_input_and_context

parent 57f4a8a9
...@@ -18,6 +18,7 @@ import os
import time
from operator import itemgetter
from bisect import bisect_right
import itertools
import json
import csv
import math
...@@ -847,3 +848,188 @@ class bert_sentencepair_dataset(data.Dataset):
            mask_labels[idx] = label
        return (output_tokens, output_types), mask, mask_labels, pad_mask
class InverseClozeDataset(data.Dataset):
"""
Dataset containing sentences and various 'blocks' for an inverse cloze task.
Arguments:
ds (Dataset or array-like): data corpus to use for training
max_seq_len (int): maximum sequence length to use for a target sentence
mask_lm_prob (float): proportion of tokens to mask for masked LM
max_preds_per_seq (int): Maximum number of masked tokens per sentence pair. Default: math.ceil(max_seq_len*mask_lm_prob/10)*10
short_seq_prob (float): Proportion of sentence pairs purposefully shorter than max_seq_len
dataset_size (int): number of random sentencepairs in the dataset. Default: len(ds)*(len(ds)-1)
"""
def __init__(self,
ds,
max_seq_len=512,
mask_lm_prob=.15,
max_preds_per_seq=None,
short_seq_prob=.01,
dataset_size=None,
presplit_sentences=False,
weighted=True,
**kwargs):
self.ds = ds
self.ds_len = len(self.ds)
self.tokenizer = self.ds.GetTokenizer()
self.vocab_words = list(self.tokenizer.text_token_vocab.values())
self.ds.SetTokenizer(None)
self.max_seq_len = max_seq_len
self.mask_lm_prob = mask_lm_prob
if max_preds_per_seq is None:
            max_preds_per_seq = math.ceil(max_seq_len * mask_lm_prob / 10) * 10
self.max_preds_per_seq = max_preds_per_seq
self.short_seq_prob = short_seq_prob
self.dataset_size = dataset_size
if self.dataset_size is None:
self.dataset_size = self.ds_len * (self.ds_len-1)
self.presplit_sentences = presplit_sentences
if not self.presplit_sentences:
nltk.download('punkt', download_dir="./nltk")
self.weighted = weighted
self.get_weighting()
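    # For reference, with the defaults above max_seq_len=512 and mask_lm_prob=0.15 give
    # max_preds_per_seq = math.ceil(512 * 0.15 / 10) * 10 = 80 masked positions per sample,
    # and a corpus of N documents yields a nominal dataset_size of N * (N - 1).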
    def get_weighting(self):
        """precompute cumulative document lengths for length-weighted document sampling"""
        if self.weighted:
            if hasattr(self.ds, 'is_lazy') and self.ds.is_lazy:
                lens = np.array(self.ds.lens)
            else:
                lens = np.array([len(d['text']) if isinstance(d, dict) else len(d) for d in self.ds])
            self.total_len = np.sum(lens)
            self.weighting = list(accumulate(lens))
        else:
            self.weighting = None
    def get_weighted_samples(self, np_rng):
        """sample a document index, in proportion to document length when weighting is enabled"""
        if self.weighting is not None:
            idx = np_rng.randint(self.total_len)
            return bisect_right(self.weighting, idx)
        else:
            return np_rng.randint(self.ds_len)
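    # For reference, a small example of the weighted sampling above: document lengths
    # lens = [10, 5, 20] give weighting = [10, 15, 35] and total_len = 35, so a uniform
    # draw idx in [0, 35) maps to document 0 for idx < 10, document 1 for 10 <= idx < 15,
    # and document 2 for idx >= 15 via bisect_right, i.e. longer documents are sampled
    # proportionally more often.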
def __len__(self):
return self.dataset_size
def __getitem__(self, idx):
# get rng state corresponding to index (allows deterministic random pair)
rng = random.Random(idx)
np_rng = np.random.RandomState(seed=[rng.randint(0, 2**32-1) for _ in range(16)])
# get seq length
target_seq_length = self.max_seq_len
if rng.random() < self.short_seq_prob:
target_seq_length = rng.randint(2, target_seq_length)
input_data, context_data, doc_idx = self.get_input_and_context(target_seq_length, rng, np_rng)
        # TODO: also fetch blocks from other documents
        # TODO: assemble and return the sample
def get_sentence_split_doc(self, idx):
"""fetch document at index idx and split into sentences"""
document = self.ds[idx]
if isinstance(document, dict):
document = document['text']
lines = document.split('\n')
if self.presplit_sentences:
return [line for line in lines if line]
rtn = []
for line in lines:
if line != '':
rtn.extend(tokenize.sent_tokenize(line))
return rtn
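    # For reference, a document "A first sentence. A second one.\nAnother line." yields
    # ['A first sentence.', 'A second one.', 'Another line.'] when presplit_sentences is
    # False (each line is run through nltk's punkt sentence tokenizer), and simply the
    # non-empty lines when presplit_sentences is True.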
def sentence_tokenize(self, sent, sentence_num=0, beginning=False, ending=False):
"""tokenize sentence and get token types"""
tokens = self.tokenizer.EncodeAsIds(sent).tokenization
str_type = 'str' + str(sentence_num)
token_types = [self.tokenizer.get_type(str_type).Id]*len(tokens)
return tokens, token_types
def get_input_and_context(self, target_seq_length, rng, np_rng):
"""fetches a sentence and its surrounding context"""
doc = doc_idx = None
while doc is None:
if self.weighted:
doc_idx = self.get_weighted_samples(np_rng)
else:
doc_idx = rng.randint(0, self.ds_len - 1)
# doc is a list of sentences
doc = self.get_sentence_split_doc(doc_idx)
if not doc:
doc = None
num_sentences = len(doc)
all_token_lists = []
all_token_type_lists = []
for sentence in doc:
tokens, token_types = self.sentence_tokenize(sentence, 0)
all_token_lists.append(tokens)
all_token_type_lists.append(token_types)
sentence_token_lens = [len(l) for l in all_token_lists]
inclusion_mask = [True] * num_sentences
input_sentence_idx = rng.randint(0, len(all_token_lists) - 1)
input_sentence_tokens = all_token_lists[input_sentence_idx].copy()
input_sentence_token_types = all_token_type_lists[input_sentence_idx].copy()
# 10% of the time, the input sentence is left in the context.
# The other 90% of the time, remove it.
if rng.random() > 0.1:
inclusion_mask[input_sentence_idx] = False
        # parameters for examining sentences to remove from the context
        remove_preceding = True
        view_radius = 0
        # keep removing sentences while the context is too large, alternating between
        # the sentences before and after the input sentence
        while sum(s for i, s in enumerate(sentence_token_lens) if inclusion_mask[i]) > target_seq_length:
            if view_radius > num_sentences:
                # nothing further can be removed; give up on shrinking the context
                break
            if remove_preceding:
                if view_radius < input_sentence_idx:
                    inclusion_mask[view_radius] = False
                view_radius += 1
            elif num_sentences - view_radius > input_sentence_idx:
                inclusion_mask[num_sentences - view_radius] = False
            remove_preceding = not remove_preceding
context_tokens = list(itertools.chain(
*[l for i, l in enumerate(all_token_lists) if inclusion_mask[i]]))
context_token_types = list(itertools.chain(
*[l for i, l in enumerate(all_token_type_lists) if inclusion_mask[i]]))
return (input_sentence_tokens, input_sentence_token_types), (context_tokens, context_token_types), doc_idx
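    # For reference, a sketch of the trimming above: for a 6-sentence document with
    # input_sentence_idx = 3 and an over-long context, sentences are dropped in the
    # order 0, 5, 1, 4, 2, alternating before and after the input sentence, until the
    # surviving sentences fit within target_seq_length.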
    def calc_seq_len(self, max_seq_len):
        """sequence budget once the 'ENC' token and two 'sep' tokens are accounted for"""
        return max_seq_len - 3
def mask_token(self, idx, tokens, types, vocab_words, rng):
"""
helper function to mask `idx` token from `tokens` according to
section 3.3.1 of https://arxiv.org/pdf/1810.04805.pdf
"""
label = tokens[idx]
if rng.random() < 0.8:
new_label = self.tokenizer.get_command('MASK').Id
else:
if rng.random() < 0.5:
new_label = label
else:
new_label = rng.choice(vocab_words)
tokens[idx] = new_label
return label
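    # For reference, the branching above follows BERT's 80/10/10 scheme: ~80% of selected
    # positions are replaced with the MASK token, ~10% keep their original token, and ~10%
    # are swapped for a random vocabulary word, while the original id is always returned
    # as the masked-LM label.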
def pad_seq(self, seq):
"""helper function to pad sequence pair"""
num_pad = max(0, self.max_seq_len - len(seq))
pad_mask = [0] * len(seq) + [1] * num_pad
seq += [self.tokenizer.get_command('pad').Id] * num_pad
return seq, pad_mask
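    # For reference, with max_seq_len = 8 a 5-token sequence is extended with three 'pad'
    # command tokens and pad_seq returns pad_mask = [0, 0, 0, 0, 0, 1, 1, 1].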
    def concat_tokens(self, tokens_a, token_types_a, tokens_b, token_types_b):
        """concatenate two token sequences, bracketed by the 'ENC' and 'sep' command tokens, with matching token types"""
        tokens = [self.tokenizer.get_command('ENC').Id] + tokens_a + [self.tokenizer.get_command('sep').Id] + tokens_b + [self.tokenizer.get_command('sep').Id]
        token_types = [token_types_a[0]] + token_types_a + [token_types_a[0]] + token_types_b + [token_types_b[0]]
        return tokens, token_types
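A minimal usage sketch of the class added in this commit, assuming a corpus object exposing the GetTokenizer()/SetTokenizer() interface the constructor expects; the name `corpus` and the parameter values below are illustrative placeholders, not part of the commit:

import random
import numpy as np

# `corpus` is a hypothetical dataset whose items are dicts with a 'text' field (or raw strings).
ict_dataset = InverseClozeDataset(corpus, max_seq_len=288, presplit_sentences=False, weighted=True)

# Deterministic RNGs, seeded the same way __getitem__ seeds them from an example index.
rng = random.Random(0)
np_rng = np.random.RandomState(seed=[rng.randint(0, 2**32 - 1) for _ in range(16)])

# Draw one (input sentence, context block) pair with the helper written in this commit;
# __getitem__ itself does not yet assemble and return a full training sample.
(input_tokens, input_types), (context_tokens, context_types), doc_idx = \
    ict_dataset.get_input_and_context(ict_dataset.calc_seq_len(288), rng, np_rng)
context_tokens, context_pad_mask = ict_dataset.pad_seq(list(context_tokens))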