Commit 2bba7f81 authored by Matthew Carrigan

Added a --reduce_memory option to shelve docs to disc instead of keeping them in memory.

parent 8733ffcb
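In outline, the new option swaps the in-memory list of tokenized documents for an on-disc shelve database held in a temporary (or user-supplied) working directory. The snippet below is a minimal, hypothetical sketch of that pattern only, not the committed code: documents are written to a shelf keyed by their index, so only the document currently being read needs to live in memory.

import shelve
from pathlib import Path
from tempfile import TemporaryDirectory

temp_dir = TemporaryDirectory()
shelf_path = Path(temp_dir.name) / 'shelf.db'
shelf = shelve.open(str(shelf_path), flag='n', protocol=-1)   # 'n' always creates a fresh, empty shelf

documents = [["hello world".split()], ["another toy doc".split()]]   # stand-ins for tokenized documents
for idx, doc in enumerate(documents):
    shelf[str(idx)] = doc          # each document is pickled to disc under its index

print(shelf['0'])                  # only the requested document is loaded back into memory
shelf.close()
temp_dir.cleanup()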
from argparse import ArgumentParser
from pathlib import Path
from tqdm import tqdm, trange
from tempfile import TemporaryDirectory
import shelve
from random import random, randint, shuffle, choice, sample
from pytorch_pretrained_bert.tokenization import BertTokenizer
import numpy as np
import json
class DocumentDatabase:
    def __init__(self, reduce_memory=False, working_dir=None):
        # With reduce_memory, documents are stored in an on-disc shelf; otherwise in an in-memory list
        if reduce_memory:
            if working_dir is None:
                self.temp_dir = TemporaryDirectory()
                self.working_dir = Path(self.temp_dir.name)
            else:
                self.temp_dir = None
                self.working_dir = Path(working_dir)
            self.working_dir.mkdir(parents=True, exist_ok=True)
            self.document_shelf_filepath = self.working_dir / 'shelf.db'
            self.document_shelf = shelve.open(str(self.document_shelf_filepath),
                                              flag='n', protocol=-1)
            self.documents = None
        else:
            self.documents = []
            self.document_shelf = None
            self.document_shelf_filepath = None
        self.doc_lengths = []
        self.doc_cumsum = None
        self.cumsum_max = None
        self.reduce_memory = reduce_memory
    def add_document(self, document):
        if self.reduce_memory:
            current_idx = len(self.doc_lengths)
            self.document_shelf[str(current_idx)] = document
        else:
            self.documents.append(document)
        self.doc_lengths.append(len(document))

    def _precalculate_doc_weights(self):
        self.doc_cumsum = np.cumsum(self.doc_lengths)
        self.cumsum_max = self.doc_cumsum[-1]
    def sample_doc(self, current_idx, sentence_weighted=True):
        # Uses the current iteration counter to ensure we don't sample the same doc twice
        if sentence_weighted:
            # With sentence weighting, we sample docs proportionally to their sentence length
            if self.doc_cumsum is None or len(self.doc_cumsum) != len(self.doc_lengths):
                self._precalculate_doc_weights()
            # Pick a random global sentence index outside the current document, then map it back to a
            # document index with searchsorted. Starting the range at doc_cumsum[current_idx] and
            # wrapping with a modulo skips over the current document's own sentences.
            # Worked example (hypothetical numbers): doc_lengths [3, 5, 2] gives doc_cumsum [3, 8, 10]
            # and cumsum_max 10; for current_idx=1, rand_start=8 and rand_end=13, so sentence_index
            # lands in {8, 9, 0, 1, 2}, which searchsorted maps to doc 0 or doc 2 only.
            rand_start = self.doc_cumsum[current_idx]
            rand_end = rand_start + self.cumsum_max - self.doc_lengths[current_idx]
            # randint's upper bound is inclusive, so stop at rand_end - 1 to avoid wrapping back
            # into the current document
            sentence_index = randint(rand_start, rand_end - 1) % self.cumsum_max
            sampled_doc_index = np.searchsorted(self.doc_cumsum, sentence_index, side='right')
        else:
            # If we don't use sentence weighting, then every doc has an equal chance to be chosen;
            # the modulo keeps the index in range while still never returning the current doc
            sampled_doc_index = (current_idx + randint(1, len(self.doc_lengths) - 1)) % len(self.doc_lengths)
        assert sampled_doc_index != current_idx
        if self.reduce_memory:
            return self.document_shelf[str(sampled_doc_index)]
        else:
            return self.documents[sampled_doc_index]
    def __len__(self):
        return len(self.doc_lengths)

    def __getitem__(self, item):
        if self.reduce_memory:
            return self.document_shelf[str(item)]
        else:
            return self.documents[item]

    def cleanup(self):
        if self.document_shelf is not None:
            self.document_shelf.close()
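For illustration only, a hypothetical way the class above might be exercised on its own (the two documents here are made up; in the committed script the database is filled from the tokenized training corpus):

# Hypothetical usage sketch, not part of the commit: each document is a list of tokenized sentences
docs = DocumentDatabase(reduce_memory=True)          # shelve documents to a temporary directory on disc
docs.add_document([["this", "is", "sentence", "one"],
                   ["and", "sentence", "two"]])
docs.add_document([["a", "second", "document"]])
other_doc = docs.sample_doc(current_idx=0)           # sample a different doc, e.g. for next-sentence pairs
assert other_doc == docs[1]                          # with only two docs, sampling from doc 0 must return doc 1
docs.cleanup()                                       # close the shelf when finished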
def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens):
@@ -200,6 +235,11 @@ def main():
"bert-base-multilingual", "bert-base-chinese"])
parser.add_argument("--do_lower_case", action="store_true")
parser.add_argument("--reduce_memory", action="store_true",
help="Reduce memory usage for large datasets by keeping data on disc rather than in memory")
parser.add_argument("--working_dir", type=Path, default=None,
help="Temporary directory to use for --reduce_memory. If not set, uses TemporaryDirectory()")
parser.add_argument("--epochs_to_generate", type=int, default=3,
help="Number of epochs of data to pregenerate")
parser.add_argument("--max_seq_len", type=int, default=128)
@@ -212,31 +252,21 @@ def main():
    args = parser.parse_args()

    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
    vocab_list = list(tokenizer.vocab.keys())
    docs = DocumentDatabase(reduce_memory=args.reduce_memory, working_dir=args.working_dir)
    with args.train_corpus.open() as f:
        doc = []
        for line in tqdm(f, desc="Loading Dataset", unit=" lines"):
            line = line.strip()
            if line == "":
                docs.add_document(doc)
                doc = []
            else:
                tokens = tokenizer.tokenize(line)
                doc.append(tokens)

    args.output_dir.mkdir(exist_ok=True)
    # When choosing a random sentence, we should sample docs proportionally to the number of sentences they contain
    # Google BERT doesn't do this, and as a result oversamples shorter docs
    for epoch in trange(args.epochs_to_generate, desc="Epoch"):
        epoch_filename = args.output_dir / f"epoch_{epoch}.json"
        num_instances = 0
@@ -257,6 +287,7 @@ def main():
"max_seq_len": args.max_seq_len
}
metrics_file.write(json.dumps(metrics))
docs.cleanup()
if __name__ == '__main__':
......