"vscode:/vscode.git/clone" did not exist on "9c6a48c8c3e73ed3272c07f897930c17d5b0286d"
Commit dbbd6c75 authored by Matthew Carrigan's avatar Matthew Carrigan
Browse files

Replaced some randints with cleaner randranges, and added a helpful

error for users whose corpus is just one giant document.
parent 61674333
......@@ -4,7 +4,7 @@ from tqdm import tqdm, trange
from tempfile import TemporaryDirectory
import shelve
from random import random, randint, shuffle, choice, sample
from random import random, randrange, randint, shuffle, choice, sample
from pytorch_pretrained_bert.tokenization import BertTokenizer
import numpy as np
import json
......@@ -30,6 +30,8 @@ class DocumentDatabase:
self.reduce_memory = reduce_memory
def add_document(self, document):
if not document:
return
if self.reduce_memory:
current_idx = len(self.doc_lengths)
self.document_shelf[str(current_idx)] = document
......@@ -49,11 +51,11 @@ class DocumentDatabase:
self._precalculate_doc_weights()
rand_start = self.doc_cumsum[current_idx]
rand_end = rand_start + self.cumsum_max - self.doc_lengths[current_idx]
sentence_index = randint(rand_start, rand_end-1) % self.cumsum_max
sentence_index = randrange(rand_start, rand_end) % self.cumsum_max
sampled_doc_index = np.searchsorted(self.doc_cumsum, sentence_index, side='right')
else:
# If we don't use sentence weighting, then every doc has an equal chance to be chosen
sampled_doc_index = current_idx + randint(1, len(self.doc_lengths)-1)
sampled_doc_index = (current_idx + randrange(1, len(self.doc_lengths))) % len(self.doc_lengths)
assert sampled_doc_index != current_idx
if self.reduce_memory:
return self.document_shelf[str(sampled_doc_index)]
......@@ -170,7 +172,7 @@ def create_instances_from_document(
# (first) sentence.
a_end = 1
if len(current_chunk) >= 2:
a_end = randint(1, len(current_chunk) - 1)
a_end = randrange(1, len(current_chunk))
tokens_a = []
for j in range(a_end):
......@@ -186,7 +188,7 @@ def create_instances_from_document(
# Sample a random document, with longer docs being sampled more frequently
random_document = doc_database.sample_doc(current_idx=doc_idx, sentence_weighted=True)
random_start = randint(0, len(random_document) - 1)
random_start = randrange(0, len(random_document))
for j in range(random_start, len(random_document)):
tokens_b.extend(random_document[j])
if len(tokens_b) >= target_b_length:
......@@ -264,6 +266,14 @@ def main():
else:
tokens = tokenizer.tokenize(line)
doc.append(tokens)
if doc:
docs.add_document(doc) # If the last doc didn't end on a newline, make sure it still gets added
if len(docs) <= 1:
exit("ERROR: No document breaks were found in the input file! These are necessary to allow the script to "
"ensure that random NextSentences are not sampled from the same document. Please add blank lines to "
"indicate breaks between documents in your input file. If your dataset does not contain multiple "
"documents, blank lines can be inserted at any natural boundary, such as the ends of chapters, "
"sections or paragraphs.")
args.output_dir.mkdir(exist_ok=True)
for epoch in trange(args.epochs_to_generate, desc="Epoch"):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment