Commit dbbd6c75 authored by Matthew Carrigan

Replaced some randints with cleaner randranges, and added a helpful error for users whose corpus is just one giant document.
parent 61674333
@@ -4,7 +4,7 @@ from tqdm import tqdm, trange
 from tempfile import TemporaryDirectory
 import shelve
-from random import random, randint, shuffle, choice, sample
+from random import random, randrange, randint, shuffle, choice, sample
 from pytorch_pretrained_bert.tokenization import BertTokenizer
 import numpy as np
 import json
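For context on the swap: randint(a, b) is inclusive of both endpoints, while randrange(a, b) excludes the stop value, matching range() semantics, so the off-by-one adjustments like "- 1" in the hunks below disappear. A minimal standalone illustration (not part of the patch):

    from random import randint, randrange

    # All three draw uniformly from 0..9; randrange's stop value is
    # exclusive (like range()), so no "- 1" adjustment is needed.
    assert randint(0, 9) in range(10)
    assert randrange(0, 10) in range(10)
    assert randrange(10) in range(10)  # the start defaults to 0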
@@ -30,6 +30,8 @@ class DocumentDatabase:
         self.reduce_memory = reduce_memory
 
     def add_document(self, document):
+        if not document:
+            return
         if self.reduce_memory:
             current_idx = len(self.doc_lengths)
             self.document_shelf[str(current_idx)] = document
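The new guard makes add_document a no-op for empty documents, so stray blank lines in the corpus (e.g. two consecutive document breaks) cannot create zero-length entries that would skew sampling. A quick sketch of the behaviour, assuming it runs in the script's scope and that DocumentDatabase.__len__ reports the number of stored documents, as the len(docs) check later in this commit suggests:

    db = DocumentDatabase(reduce_memory=False)
    db.add_document([])                      # ignored by the new guard
    db.add_document([["hello", "world"]])    # one doc with one tokenized sentence
    assert len(db) == 1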
@@ -49,11 +51,11 @@ class DocumentDatabase:
                 self._precalculate_doc_weights()
             rand_start = self.doc_cumsum[current_idx]
             rand_end = rand_start + self.cumsum_max - self.doc_lengths[current_idx]
-            sentence_index = randint(rand_start, rand_end-1) % self.cumsum_max
+            sentence_index = randrange(rand_start, rand_end) % self.cumsum_max
             sampled_doc_index = np.searchsorted(self.doc_cumsum, sentence_index, side='right')
         else:
             # If we don't use sentence weighting, then every doc has an equal chance to be chosen
-            sampled_doc_index = current_idx + randint(1, len(self.doc_lengths)-1)
+            sampled_doc_index = (current_idx + randrange(1, len(self.doc_lengths))) % len(self.doc_lengths)
         assert sampled_doc_index != current_idx
         if self.reduce_memory:
             return self.document_shelf[str(sampled_doc_index)]
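In the unweighted branch, offsetting by randrange(1, n) and wrapping with % n samples uniformly over every index except current_idx, since the offset is never a multiple of n; the added modulo also keeps the result in range, which the old unwrapped current_idx + randint(...) did not guarantee. A standalone sketch of the trick (illustrative names, not from the patch):

    from random import randrange

    def sample_other_index(current_idx, n):
        # Offset by 1..n-1, then wrap: uniform over the n-1 indices
        # that are not current_idx, and always within 0..n-1.
        return (current_idx + randrange(1, n)) % n

    assert all(sample_other_index(2, 5) != 2 for _ in range(1000))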
@@ -170,7 +172,7 @@ def create_instances_from_document(
                 # (first) sentence.
                 a_end = 1
                 if len(current_chunk) >= 2:
-                    a_end = randint(1, len(current_chunk) - 1)
+                    a_end = randrange(1, len(current_chunk))
 
                 tokens_a = []
                 for j in range(a_end):
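Here randrange(1, len(current_chunk)) chooses the A/B split point so that segment A gets at least one sentence and at least one sentence is left over for segment B. A sketch under illustrative names:

    from random import randrange

    chunk = [["first", "sent"], ["second", "sent"], ["third", "sent"]]
    a_end = randrange(1, len(chunk))  # 1 or 2: both segments stay non-empty
    tokens_a = [tok for sent in chunk[:a_end] for tok in sent]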
@@ -186,7 +188,7 @@ def create_instances_from_document(
                     # Sample a random document, with longer docs being sampled more frequently
                     random_document = doc_database.sample_doc(current_idx=doc_idx, sentence_weighted=True)
 
-                    random_start = randint(0, len(random_document) - 1)
+                    random_start = randrange(0, len(random_document))
                     for j in range(random_start, len(random_document)):
                         tokens_b.extend(random_document[j])
                         if len(tokens_b) >= target_b_length:
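For the random-next-sentence case, the code takes a uniformly random suffix of the sampled document to fill tokens_b; note that randrange(0, n) is equivalent to randrange(n). A standalone sketch (illustrative helper, assumes a non-empty document, which the new add_document guard ensures):

    from random import randrange

    def random_suffix(sentences):
        # Start at a uniformly random sentence and keep everything after it.
        start = randrange(len(sentences))
        return sentences[start:]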
@@ -264,6 +266,14 @@ def main():
                 else:
                     tokens = tokenizer.tokenize(line)
                     doc.append(tokens)
+            if doc:
+                docs.add_document(doc)  # If the last doc didn't end on a newline, make sure it still gets added
+            if len(docs) <= 1:
+                exit("ERROR: No document breaks were found in the input file! These are necessary to allow the script to "
+                     "ensure that random NextSentences are not sampled from the same document. Please add blank lines to "
+                     "indicate breaks between documents in your input file. If your dataset does not contain multiple "
+                     "documents, blank lines can be inserted at any natural boundary, such as the ends of chapters, "
+                     "sections or paragraphs.")
 
         args.output_dir.mkdir(exist_ok=True)
         for epoch in trange(args.epochs_to_generate, desc="Epoch"):
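The new error fires when the whole corpus parses as a single document. As the message explains, the expected input format is one sentence per line with a blank line between documents, along these lines (hypothetical excerpt):

    The first sentence of document one.
    The second sentence of document one.

    Document two starts after a blank line.
    Blank lines can also mark chapter, section or paragraph boundaries.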