"awq/git@developer.sourcefind.cn:OpenDAS/autoawq.git" did not exist on "eb85f67d36ccd72e7cdf5cdc29954a265603e062"
Commit c125d247 authored by Mohammad Shoeybi

built simple test for dataset

parent 7120e931
@@ -11,11 +11,11 @@ from torch.utils.data import Dataset
 # WILL BE REPLACED WITH JARED'S
 class JaredDataset(object):
-    def __init__(self):
-        self.doc_idx = []
+    def __init__(self, doc_idx, sizes, sentences):
+        self.doc_idx = doc_idx
         self.num_docs = len(self.doc_idx) - 1
-        self.sizes = []
-        self.sentences = []
+        self.sizes = sizes
+        self.sentences = sentences
     def __getitem__(self, idx):
         return self.sentences[idx]
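Note on the new constructor: doc_idx acts as a cumulative sentence-offset array, so document d owns sentences doc_idx[d]:doc_idx[d+1], which is why num_docs is len(doc_idx) - 1. A minimal sketch of that convention, with made-up token ids that are not from the commit:

# Hypothetical toy instance of the offset convention above.
doc_idx = [0, 3, 4]                          # cumulative sentence offsets
sentences = [[5, 6], [7], [8, 9, 10], [11]]  # token ids per sentence
sizes = [len(s) for s in sentences]          # tokens per sentence
ds = JaredDataset(doc_idx, sizes, sentences)
assert ds.num_docs == 2
assert ds[0] == [5, 6]                       # __getitem__ returns one sentence
# Sentences of document 1:
assert sentences[doc_idx[1]:doc_idx[2]] == [[11]]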
@@ -62,7 +62,7 @@ def build_training_samples_mapping(indexed_dataset, num_epochs, max_seq_length,
         # Document sentences are in [sent_index_first, sent_index_last).
         sent_index_first = indexed_dataset.doc_idx[doc_index]
         sent_index_last = indexed_dataset.doc_idx[doc_index+1]
-        assert sent_index_last >= sent_index_first:
+        assert sent_index_last >= sent_index_first
         # Empty docs.
         if (sent_index_last - sent_index_first) == 0:
@@ -82,7 +82,7 @@ def build_training_samples_mapping(indexed_dataset, num_epochs, max_seq_length,
         # Loop through sentences.
         sent_index = sent_index_first
         target_seq_length = get_target_seq_length(max_num_tokens,
-                                                  short_seq_prob, rng)
+                                                  short_seq_prob, np_rng)
         size = 0
         while sent_index < sent_index_last:
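get_target_seq_length itself is outside this diff; judging from the call signature and the usual BERT-style short_seq_prob convention, it presumably returns max_num_tokens except that, with probability short_seq_prob, it samples a shorter target. A hedged sketch only; the body below is an assumption, and np_rng is assumed to be a seeded numpy RandomState:

import numpy as np

def get_target_seq_length(max_num_tokens, short_seq_prob, np_rng):
    # Assumed behavior: occasionally train on shorter sequences so the
    # model also sees short examples (standard BERT-style pretraining).
    if np_rng.rand() < short_seq_prob:
        return np_rng.randint(2, max_num_tokens + 1)
    return max_num_tokens

np_rng = np.random.RandomState(1234)  # presumably built from `seed`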
@@ -94,19 +94,22 @@ def build_training_samples_mapping(indexed_dataset, num_epochs, max_seq_length,
             exceeded_target_size = (size >= target_seq_length)
             # If only one sentence is left in the document.
             only_one_sent_left = (sent_index == (sent_index_last - 1))
+            # If we have at least two sentences.
+            have_more_than_one_sent = (sent_index - sent_index_first) > 1
             # If we have reached end of the document.
             reached_end_of_doc = (sent_index == sent_index_last)
-            if (exceeded_target_size and not only_one_sent_left) or \
-               reached_end_of_doc:
+            if (exceeded_target_size and not only_one_sent_left and
+                have_more_than_one_sent) or reached_end_of_doc:
                 assert (sent_index - sent_index_first) > 1
                 assert size > 1
                 # Add the sample.
-                samples.append([sent_index_first, sent_index])
+                samples.append([sent_index_first, sent_index,
+                                target_seq_length])
                 # Reset indices
                 sent_index_first = sent_index
                 target_seq_length = get_target_seq_length(max_num_tokens,
                                                           short_seq_prob,
-                                                          rng)
+                                                          np_rng)
                 size = 0
                 num_sentences = 0
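The cut condition now also requires the pending sample to span at least two sentences (have_more_than_one_sent), which is what backs the assert right below it, and each emitted sample carries its target_seq_length so the length choice can be replayed at sample time. A standalone toy replay of the packing rule, with made-up sentence sizes and a fixed target; the per-sentence increment of sent_index happens in lines elided from this hunk:

# Toy replay of the greedy packing rule above; sizes/target are made up.
sizes = [60, 50, 70, 40, 90]   # tokens per sentence of one document
target = 100
samples, first, size, sent = [], 0, 0, 0
while sent < len(sizes):
    size += sizes[sent]
    sent += 1
    exceeded = (size >= target)
    only_one_left = (sent == len(sizes) - 1)
    more_than_one = (sent - first) > 1
    end_of_doc = (sent == len(sizes))
    if (exceeded and not only_one_left and more_than_one) or end_of_doc:
        samples.append([first, sent, target])
        first, size = sent, 0
# The only_one_left guard keeps the final sentence from becoming an
# orphan sample of its own.
assert samples == [[0, 2, 100], [2, 5, 100]]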
@@ -132,16 +135,16 @@ def build_training_samples_mapping(indexed_dataset, num_epochs, max_seq_length,
 class AlbertDataSet(Dataset):
-    def __init__(self, tokenizer, num_epochs, masked_lm_prob, max_seq_length
-                 short_seq_prob, seed):
+    def __init__(self, indexed_dataset, tokenizer, num_epochs,
+                 masked_lm_prob, max_seq_length, short_seq_prob, seed):
         # Params to store.
         self.seed = seed
         self.masked_lm_prob = masked_lm_prob
         self.max_seq_length = max_seq_length
-        # Build the indexed dataset.
-        self.indexed_dataset = JaredDataset()
+        # Indexed dataset.
+        self.indexed_dataset = indexed_dataset
         # Build the samples mapping.
         self.samples_mapping = build_training_samples_mapping(
@@ -181,3 +184,48 @@ class AlbertDataSet(Dataset):
 if __name__ == '__main__':
     print('dataset ...')
+    from bert_tokenization import FullTokenizer
+    import json
+    import nltk
+    nltk.download('punkt')
+
+    def document_generator_provider(input_file):
+        with open(input_file, 'r') as ifile:
+            for document in ifile:
+                data = json.loads(document)
+                text = data['text']
+                sentences = []
+                for line in text.split('\n'):
+                    if line != '\n':
+                        sentences.extend(nltk.tokenize.sent_tokenize(line))
+                yield sentences
+
+    input_file = '/raid/mshoeybi/data/albert/sample/samples_11.json'
+    vocab_file = '/raid/mshoeybi/data/albert/bert_vocab/vocab.txt'
+    tokenizer = FullTokenizer(vocab_file, do_lower_case=True)
+    document_generator = document_generator_provider(input_file)
+
+    doc_idx = [0]
+    sizes = []
+    sentences_list = []
+    for sentences in document_generator:
+        doc_idx.append(len(sentences))
+        for sentence in sentences:
+            tokens = tokenizer.tokenize(sentence)
+            ids = tokenizer.convert_tokens_to_ids(tokens)
+            sizes.append(len(ids))
+            sentences_list.append(ids)
+    for i in range(1, len(doc_idx)):
+        doc_idx[i] += doc_idx[i-1]
+    indexed_dataset = JaredDataset(doc_idx, sizes, sentences_list)
+
+    dataset = AlbertDataSet(indexed_dataset=indexed_dataset,
+                            tokenizer=tokenizer,
+                            num_epochs=3,
+                            masked_lm_prob=0.15,
+                            max_seq_length=512,
+                            short_seq_prob=0.1,
+                            seed=1234)
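The test first stores per-document sentence counts in doc_idx (after the leading 0) and then converts them in place to cumulative offsets with the prefix-sum loop, matching the [doc_idx[d], doc_idx[d+1]) convention consumed by build_training_samples_mapping. A tiny check of that conversion with made-up counts:

# Toy check of the in-place prefix sum above; counts are made up.
doc_idx = [0, 3, 2]            # leading 0, then sentences per document
for i in range(1, len(doc_idx)):
    doc_idx[i] += doc_idx[i-1]
assert doc_idx == [0, 3, 5]    # doc 0 -> sentences [0, 3), doc 1 -> [3, 5)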