Commit 2bba7f81 authored by Matthew Carrigan

Added a --reduce_memory option to shelve docs to disc instead of keeping them in memory.

parent 8733ffcb
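In outline, the new option backs the document store with Python's shelve module: each tokenized document is pickled to an on-disk database keyed by its index, and only lightweight metadata such as per-document sentence counts stays in memory. A minimal standalone sketch of that idea (illustrative only, not the committed DocumentDatabase class):

    import shelve
    from pathlib import Path
    from tempfile import TemporaryDirectory

    # Illustrative sketch of the on-disk document store idea used by --reduce_memory.
    temp_dir = TemporaryDirectory()
    shelf_path = Path(temp_dir.name) / 'shelf.db'

    # flag='n' always creates a fresh database; protocol=-1 uses the newest pickle protocol.
    docs = shelve.open(str(shelf_path), flag='n', protocol=-1)
    doc_lengths = []  # only lightweight metadata stays in RAM

    for idx, document in enumerate([["hello", "world"], ["another", "doc", "here"]]):
        docs[str(idx)] = document          # pickled to disk, not kept in memory
        doc_lengths.append(len(document))

    print(docs['1'], doc_lengths)          # documents are read back on demand
    docs.close()
    temp_dir.cleanup()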
 from argparse import ArgumentParser
 from pathlib import Path
 from tqdm import tqdm, trange
+from tempfile import TemporaryDirectory
+import shelve
 from random import random, randint, shuffle, choice, sample
 from pytorch_pretrained_bert.tokenization import BertTokenizer
+import numpy as np
 import json
 
 
 class DocumentDatabase:
-    def __init__(self, document_list):
-        self.document_list = document_list
-        self.doc_starts = {}
-        self.weighted_doc_samples = []
-        i = 0
-        for doc_idx, doc in enumerate(document_list):
-            self.doc_starts[doc_idx] = i
-            self.weighted_doc_samples.extend([doc_idx] * len(doc))
-            i += len(doc)
+    def __init__(self, reduce_memory=False, working_dir=None):
+        if reduce_memory:
+            if working_dir is None:
+                self.temp_dir = TemporaryDirectory()
+                self.working_dir = Path(self.temp_dir.name)
+            else:
+                self.temp_dir = None
+                self.working_dir = Path(working_dir)
+                self.working_dir.mkdir(parents=True, exist_ok=True)
+            self.document_shelf_filepath = self.working_dir / 'shelf.db'
+            self.document_shelf = shelve.open(str(self.document_shelf_filepath),
+                                              flag='n', protocol=-1)
+            self.documents = None
+        else:
+            self.documents = []
+            self.document_shelf = None
+            self.document_shelf_filepath = None
+        self.doc_lengths = []
+        self.doc_cumsum = None
+        self.cumsum_max = None
+        self.reduce_memory = reduce_memory
+
+    def add_document(self, document):
+        if self.reduce_memory:
+            current_idx = len(self.doc_lengths)
+            self.document_shelf[str(current_idx)] = document
+        else:
+            self.documents.append(document)
+        self.doc_lengths.append(len(document))
+
+    def _precalculate_doc_weights(self):
+        self.doc_cumsum = np.cumsum(self.doc_lengths)
+        self.cumsum_max = self.doc_cumsum[-1]
 
     def sample_doc(self, current_idx, sentence_weighted=True):
         # Uses the current iteration counter to ensure we don't sample the same doc twice
         if sentence_weighted:
-            num_sentences = len(self.document_list[current_idx])
-            # This very painful line randomly selects a document, weighted by the number of sentences they contain,
-            # while guaranteeing that it won't return the original document
-            sampled_val = (
-                (self.doc_starts[current_idx] + num_sentences
-                 + randint(0, len(self.weighted_doc_samples) - num_sentences - 1))
-                % len(self.weighted_doc_samples))
-            sampled_doc_index = self.weighted_doc_samples[sampled_val]
+            # With sentence weighting, we sample docs proportionally to their sentence length
+            if self.doc_cumsum is None or len(self.doc_cumsum) != len(self.doc_lengths):
+                self._precalculate_doc_weights()
+            rand_start = self.doc_cumsum[current_idx]
+            rand_end = rand_start + self.cumsum_max - self.doc_lengths[current_idx]
+            sentence_index = randint(rand_start, rand_end) % self.cumsum_max
+            sampled_doc_index = np.searchsorted(self.doc_cumsum, sentence_index, side='right')
         else:
             # If we don't use sentence weighting, then every doc has an equal chance to be chosen
-            sampled_doc_index = current_idx + randint(1, len(self.document_list)-1)
+            sampled_doc_index = current_idx + randint(1, len(self.doc_lengths)-1)
         assert sampled_doc_index != current_idx
-        return self.document_list[sampled_doc_index]
+        if self.reduce_memory:
+            return self.document_shelf[str(sampled_doc_index)]
+        else:
+            return self.documents[sampled_doc_index]
 
     def __len__(self):
-        return len(self.document_list)
+        return len(self.doc_lengths)
 
     def __getitem__(self, item):
-        return self.document_list[item]
+        if self.reduce_memory:
+            return self.document_shelf[str(item)]
+        else:
+            return self.documents[item]
+
+    def cleanup(self):
+        if self.document_shelf is not None:
+            self.document_shelf.close()
 
 
 def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens):
@@ -200,6 +235,11 @@ def main():
                                  "bert-base-multilingual", "bert-base-chinese"])
     parser.add_argument("--do_lower_case", action="store_true")
+    parser.add_argument("--reduce_memory", action="store_true",
+                        help="Reduce memory usage for large datasets by keeping data on disc rather than in memory")
+    parser.add_argument("--working_dir", type=Path, default=None,
+                        help="Temporary directory to use for --reduce_memory. If not set, uses TemporaryDirectory()")
     parser.add_argument("--epochs_to_generate", type=int, default=3,
                         help="Number of epochs of data to pregenerate")
     parser.add_argument("--max_seq_len", type=int, default=128)
@@ -212,31 +252,21 @@ def main():
     args = parser.parse_args()
 
-    # TODO Add a low-memory / multiprocessing path for very large datasets
-    # In this path documents would be stored in a shelf after being tokenized, and multiple processes would convert
-    # those docs into training examples that would be written out on the fly. This would avoid the need to keep
-    # the whole training set in memory and would speed up dataset creation at the cost of code complexity.
-    # In addition, the finetuning script would need to be modified
-    # to store the training epochs as memmapped arrays.
     tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
     vocab_list = list(tokenizer.vocab.keys())
+    docs = DocumentDatabase(reduce_memory=args.reduce_memory, working_dir=args.working_dir)
     with args.train_corpus.open() as f:
-        docs = []
         doc = []
-        for line in tqdm(f, desc="Loading Dataset"):
+        for line in tqdm(f, desc="Loading Dataset", unit=" lines"):
             line = line.strip()
             if line == "":
-                docs.append(doc)
+                docs.add_document(doc)
                 doc = []
             else:
                 tokens = tokenizer.tokenize(line)
                 doc.append(tokens)
 
     args.output_dir.mkdir(exist_ok=True)
-    docs = DocumentDatabase(docs)
-    # When choosing a random sentence, we should sample docs proportionally to the number of sentences they contain
-    # Google BERT doesn't do this, and as a result oversamples shorter docs
     for epoch in trange(args.epochs_to_generate, desc="Epoch"):
         epoch_filename = args.output_dir / f"epoch_{epoch}.json"
         num_instances = 0
@@ -257,6 +287,7 @@ def main():
                 "max_seq_len": args.max_seq_len
             }
            metrics_file.write(json.dumps(metrics))
+    docs.cleanup()
 
 
 if __name__ == '__main__':
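The sentence-weighted sampling in the new sample_doc can be hard to read at a glance. The sketch below (illustrative values, separate from the commit) shows the core trick: draw a random sentence position and map it back to the document that owns it with np.searchsorted over the cumulative sentence counts, so longer documents are chosen proportionally more often; sample_doc additionally offsets the draw by the current document's range so the same document is not returned.

    import numpy as np
    from random import randint

    # Illustrative sketch: map a random "sentence position" back to the document
    # that owns it, so longer documents are sampled proportionally more often.
    doc_lengths = [2, 5, 3]                       # sentences per document
    doc_cumsum = np.cumsum(doc_lengths)           # array([2, 7, 10])
    total_sentences = int(doc_cumsum[-1])         # 10

    sentence_index = randint(0, total_sentences - 1)   # uniform over all sentences
    doc_index = int(np.searchsorted(doc_cumsum, sentence_index, side='right'))

    # sentence_index 0-1 -> doc 0, 2-6 -> doc 1, 7-9 -> doc 2
    print(sentence_index, "->", doc_index)

Compared with the previous weighted_doc_samples list, which stored one entry per sentence, this approach needs only a single cumulative-sum array over documents.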