Commit c125d247 authored by Mohammad Shoeybi

built simple test for dataset

parent 7120e931
@@ -11,11 +11,11 @@ from torch.utils.data import Dataset
 # WILL BE REPLACED WITH JARED'S
 class JaredDataset(object):
 
-    def __init__(self):
-        self.doc_idx = []
+    def __init__(self, doc_idx, sizes, sentences):
+        self.doc_idx = doc_idx
         self.num_docs = len(self.doc_idx) - 1
-        self.sizes = []
-        self.sentences = []
+        self.sizes = sizes
+        self.sentences = sentences
 
     def __getitem__(self, idx):
         return self.sentences[idx]
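+
+# Illustrative use of this placeholder (hypothetical token ids), e.g.
+#   ds = JaredDataset(doc_idx=[0, 2, 3], sizes=[5, 7, 4],
+#                     sentences=[[1, 2, 3, 4, 5],
+#                                [6, 7, 8, 9, 10, 11, 12],
+#                                [13, 14, 15, 16]])
+# gives ds.num_docs == 2; document d spans sentences
+# [doc_idx[d], doc_idx[d+1]), and ds[1] returns the second sentence's ids.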
@@ -62,7 +62,7 @@ def build_training_samples_mapping(indexed_dataset, num_epochs, max_seq_length,
         # Document sentences are in [sent_index_first, sent_index_last).
         sent_index_first = indexed_dataset.doc_idx[doc_index]
         sent_index_last = indexed_dataset.doc_idx[doc_index+1]
-        assert sent_index_last >= sent_index_first:
+        assert sent_index_last >= sent_index_first
 
         # Empty docs.
         if (sent_index_last - sent_index_first) == 0:
@@ -82,7 +82,7 @@ def build_training_samples_mapping(indexed_dataset, num_epochs, max_seq_length,
         # Loop through sentences.
         sent_index = sent_index_first
         target_seq_length = get_target_seq_length(max_num_tokens,
-                                                  short_seq_prob, rng)
+                                                  short_seq_prob, np_rng)
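+        # get_target_seq_length is defined in elided lines; presumably, as
+        # in BERT's pretraining data script, it returns max_num_tokens or,
+        # with probability short_seq_prob, a shorter random length such as
+        # np_rng.randint(2, max_num_tokens + 1).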
         size = 0
         while sent_index < sent_index_last:
@@ -94,19 +94,22 @@ def build_training_samples_mapping(indexed_dataset, num_epochs, max_seq_length,
             exceeded_target_size = (size >= target_seq_length)
             # If only one sentence is left in the document.
             only_one_sent_left = (sent_index == (sent_index_last - 1))
+            # If we have at least two sentences.
+            have_more_than_one_sent = (sent_index - sent_index_first) > 1
             # If we have reached the end of the document.
             reached_end_of_doc = (sent_index == sent_index_last)
 
-            if (exceeded_target_size and not only_one_sent_left) or \
-               reached_end_of_doc:
+            if (exceeded_target_size and not only_one_sent_left and
+                have_more_than_one_sent) or reached_end_of_doc:
                 assert (sent_index - sent_index_first) > 1
                 assert size > 1
                 # Add the sample.
-                samples.append([sent_index_first, sent_index])
+                samples.append([sent_index_first, sent_index,
+                                target_seq_length])
                 # Reset indices.
                 sent_index_first = sent_index
                 target_seq_length = get_target_seq_length(max_num_tokens,
                                                           short_seq_prob,
-                                                          rng)
+                                                          np_rng)
                 size = 0
                 num_sentences = 0
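+                # Worked example (hypothetical sizes): for a document with
+                # sentence sizes [60, 70, 90, 50] and target_seq_length 128,
+                # after two sentences size = 130 >= 128, more than one
+                # sentence is consumed, and two remain, so [0, 2, 128] is
+                # emitted; the last two sentences form the next sample when
+                # the end of the document is reached.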
@@ -132,16 +135,16 @@ def build_training_samples_mapping(indexed_dataset, num_epochs, max_seq_length,
 class AlbertDataSet(Dataset):
 
-    def __init__(self, tokenizer, num_epochs, masked_lm_prob, max_seq_length
-                 short_seq_prob, seed):
+    def __init__(self, indexed_dataset, tokenizer, num_epochs,
+                 masked_lm_prob, max_seq_length, short_seq_prob, seed):
 
         # Params to store.
         self.seed = seed
         self.masked_lm_prob = masked_lm_prob
         self.max_seq_length = max_seq_length
 
-        # Build the indexed dataset.
-        self.indexed_dataset = JaredDataset()
+        # Indexed dataset.
+        self.indexed_dataset = indexed_dataset
 
         # Build the samples mapping.
         self.samples_mapping = build_training_samples_mapping(
@@ -181,3 +184,48 @@ class AlbertDataSet(Dataset):
 if __name__ == '__main__':
 
     print('dataset ...')
+
+    from bert_tokenization import FullTokenizer
+    import json
+    import nltk
+    nltk.download('punkt')
+
+    def document_generator_provider(input_file):
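+        # Input is JSON lines: one document per line, each a JSON object
+        # whose 'text' field holds newline-separated paragraphs that are
+        # split into sentences with NLTK.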
+        with open(input_file, 'r') as ifile:
+            for document in ifile:
+                data = json.loads(document)
+                text = data['text']
+                sentences = []
+                for line in text.split('\n'):
+                    # split('\n') never yields '\n' itself, so skip empty
+                    # lines instead.
+                    if line:
+                        sentences.extend(nltk.tokenize.sent_tokenize(line))
+                yield sentences
+
+    input_file = '/raid/mshoeybi/data/albert/sample/samples_11.json'
+    vocab_file = '/raid/mshoeybi/data/albert/bert_vocab/vocab.txt'
+
+    tokenizer = FullTokenizer(vocab_file, do_lower_case=True)
+    document_generator = document_generator_provider(input_file)
+
+    doc_idx = [0]
+    sizes = []
+    sentences_list = []
+    for sentences in document_generator:
+        doc_idx.append(len(sentences))
+        for sentence in sentences:
+            tokens = tokenizer.tokenize(sentence)
+            ids = tokenizer.convert_tokens_to_ids(tokens)
+            sizes.append(len(ids))
+            sentences_list.append(ids)
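+    # doc_idx so far holds per-document sentence counts; prefix-sum it into
+    # cumulative offsets (e.g. [0, 3, 2] -> [0, 3, 5]) so that document d
+    # spans sentences [doc_idx[d], doc_idx[d+1]).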
+    for i in range(1, len(doc_idx)):
+        doc_idx[i] += doc_idx[i-1]
+    indexed_dataset = JaredDataset(doc_idx, sizes, sentences_list)
+
+    dataset = AlbertDataSet(indexed_dataset=indexed_dataset,
+                            tokenizer=tokenizer,
+                            num_epochs=3,
+                            masked_lm_prob=0.15,
+                            max_seq_length=512,
+                            short_seq_prob=0.1,
+                            seed=1234)
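+
+    # Hypothetical next check for this smoke test (the sample format
+    # returned by AlbertDataSet.__getitem__ lives in elided lines above):
+    #   print('num samples:', len(dataset.samples_mapping))
+    #   print('first sample:', dataset[0])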