Unverified commit 7873d764 authored by Thomas Wolf, committed by GitHub

Merge pull request #478 from Rocketknight1/master

Added a helpful error for users with single-document corpora - fixes #452
parents 61674333 dbbd6c75
@@ -4,7 +4,7 @@ from tqdm import tqdm, trange
 from tempfile import TemporaryDirectory
 import shelve
 
-from random import random, randint, shuffle, choice, sample
+from random import random, randrange, randint, shuffle, choice, sample
 from pytorch_pretrained_bert.tokenization import BertTokenizer
 import numpy as np
 import json
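The new import supports the pattern used throughout this PR: randint(a, b) samples the closed interval [a, b], while randrange(a, b) samples the half-open interval [a, b), so randint(a, b - 1) and randrange(a, b) draw from the same values and the latter drops the error-prone "- 1" on the upper bound. A minimal standalone illustration (not part of the patch):

    from random import randint, randrange

    # randint(a, b) draws from the closed interval [a, b];
    # randrange(a, b) draws from the half-open interval [a, b).
    # Both calls below draw from the same set {0, ..., 9}, but
    # randrange avoids the "- 1" adjustment on the upper bound.
    x = randint(0, 9)
    y = randrange(0, 10)
    assert 0 <= x <= 9 and 0 <= y <= 9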
@@ -30,6 +30,8 @@ class DocumentDatabase:
         self.reduce_memory = reduce_memory
 
     def add_document(self, document):
+        if not document:
+            return
         if self.reduce_memory:
             current_idx = len(self.doc_lengths)
             self.document_shelf[str(current_idx)] = document
@@ -49,11 +51,11 @@ class DocumentDatabase:
                 self._precalculate_doc_weights()
             rand_start = self.doc_cumsum[current_idx]
             rand_end = rand_start + self.cumsum_max - self.doc_lengths[current_idx]
-            sentence_index = randint(rand_start, rand_end-1) % self.cumsum_max
+            sentence_index = randrange(rand_start, rand_end) % self.cumsum_max
             sampled_doc_index = np.searchsorted(self.doc_cumsum, sentence_index, side='right')
         else:
             # If we don't use sentence weighting, then every doc has an equal chance to be chosen
-            sampled_doc_index = current_idx + randint(1, len(self.doc_lengths)-1)
+            sampled_doc_index = (current_idx + randrange(1, len(self.doc_lengths))) % len(self.doc_lengths)
         assert sampled_doc_index != current_idx
         if self.reduce_memory:
             return self.document_shelf[str(sampled_doc_index)]
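Both branches of sample_doc now guarantee a document other than current_idx: the unweighted branch adds a nonzero offset modulo the document count (note the modulo wrap the old randint version lacked), and the weighted branch draws a global sentence position starting just past the current document's block of cumulative indices, so its own sentences are excluded. A standalone sketch of the same logic with plain lists instead of the DocumentDatabase class (function names here are illustrative, not from the patch):

    from random import randrange

    import numpy as np

    # Unweighted branch: an offset in [1, n_docs) modulo n_docs
    # can never map current_idx back onto itself.
    def sample_other_doc(current_idx, n_docs):
        return (current_idx + randrange(1, n_docs)) % n_docs

    # Sentence-weighted branch: draw a position over the cumulative
    # sentence counts, starting just past the current document's block
    # so its own indices are excluded, then wrap with the modulo.
    def sample_other_doc_weighted(current_idx, doc_lengths):
        doc_cumsum = np.cumsum(doc_lengths)
        cumsum_max = int(doc_cumsum[-1])
        rand_start = int(doc_cumsum[current_idx])
        rand_end = rand_start + cumsum_max - doc_lengths[current_idx]
        sentence_index = randrange(rand_start, rand_end) % cumsum_max
        return int(np.searchsorted(doc_cumsum, sentence_index, side='right'))

    assert sample_other_doc(0, 3) in (1, 2)
    assert sample_other_doc_weighted(1, [5, 2, 10]) in (0, 2)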
@@ -170,7 +172,7 @@ def create_instances_from_document(
                 # (first) sentence.
                 a_end = 1
                 if len(current_chunk) >= 2:
-                    a_end = randint(1, len(current_chunk) - 1)
+                    a_end = randrange(1, len(current_chunk))
 
                 tokens_a = []
                 for j in range(a_end):
@@ -186,7 +188,7 @@ def create_instances_from_document(
                     # Sample a random document, with longer docs being sampled more frequently
                     random_document = doc_database.sample_doc(current_idx=doc_idx, sentence_weighted=True)
 
-                    random_start = randint(0, len(random_document) - 1)
+                    random_start = randrange(0, len(random_document))
                     for j in range(random_start, len(random_document)):
                         tokens_b.extend(random_document[j])
                         if len(tokens_b) >= target_b_length:
@@ -264,6 +266,14 @@ def main():
                 else:
                     tokens = tokenizer.tokenize(line)
                     doc.append(tokens)
+            if doc:
+                docs.add_document(doc)  # If the last doc didn't end on a newline, make sure it still gets added
+        if len(docs) <= 1:
+            exit("ERROR: No document breaks were found in the input file! These are necessary to allow the script to "
+                 "ensure that random NextSentences are not sampled from the same document. Please add blank lines to "
+                 "indicate breaks between documents in your input file. If your dataset does not contain multiple "
+                 "documents, blank lines can be inserted at any natural boundary, such as the ends of chapters, "
+                 "sections or paragraphs.")
 
         args.output_dir.mkdir(exist_ok=True)
         for epoch in trange(args.epochs_to_generate, desc="Epoch"):
...
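Before this PR, a corpus with no document breaks reached sample_doc with a single document, where the random offset has an empty range and raises a ValueError deep inside the sampling code; that confusing low-level failure is what the new len(docs) <= 1 check replaces with an actionable message. A minimal illustration of the old failure mode (assuming the single-document setup reported in issue #452):

    from random import randrange

    # With only one document there is no "other" document to sample:
    # the offset range [1, n_docs) is empty when n_docs == 1.
    try:
        randrange(1, 1)
    except ValueError as e:
        print("single-document corpus used to fail with:", e)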