Commit 7d1ae644 authored by Matthew Carrigan

Added a --reduce_memory option to the training script to keep training data on disc as a memmap rather than in memory
parent 2bba7f81
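For context on what the new `reduce_memory` path does: the per-epoch feature arrays are preallocated as `np.memmap` files inside a `TemporaryDirectory` instead of as in-RAM `np.zeros` arrays, so only the rows actually being touched get paged into memory. Below is a minimal, self-contained sketch of that pattern; the array names and shapes mirror the diff, but the sizes and the random fill are illustrative only and not taken from the script.

```python
# Sketch of the --reduce_memory idea: back the feature arrays with on-disk
# np.memmap files in a TemporaryDirectory instead of in-RAM np.zeros, so a
# pregenerated epoch of BERT training data never has to fit in memory.
# num_samples / seq_len and the array names mirror the diff; the concrete
# values below are made up for illustration.
from pathlib import Path
from tempfile import TemporaryDirectory

import numpy as np

num_samples, seq_len = 100_000, 128  # illustrative sizes

temp_dir = TemporaryDirectory()
working_dir = Path(temp_dir.name)

# mode='w+' creates the backing file and maps it read/write; rows are paged
# in on demand, so resident memory stays small even for huge datasets.
input_ids = np.memmap(working_dir / 'input_ids.memmap',
                      mode='w+', dtype=np.int32, shape=(num_samples, seq_len))
lm_label_ids = np.memmap(working_dir / 'lm_label_ids.memmap',
                         mode='w+', dtype=np.int32, shape=(num_samples, seq_len))
lm_label_ids[:] = -1  # ignored-position marker, as in the diff

# Filling and indexing work exactly like ordinary ndarrays:
input_ids[0] = np.random.randint(0, 30000, size=seq_len)
row = input_ids[0]              # reads one row back from disk

del input_ids, lm_label_ids     # drop the maps before deleting the backing files
temp_dir.cleanup()
```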
@@ -6,6 +6,7 @@ import json
 import random
 import numpy as np
 from collections import namedtuple
+from tempfile import TemporaryDirectory
 from torch.utils.data import DataLoader, Dataset, RandomSampler
 from torch.utils.data.distributed import DistributedSampler
@@ -53,8 +54,7 @@ def convert_example_to_features(example, tokenizer, max_seq_length):
 class PregeneratedDataset(Dataset):
-    def __init__(self, training_path, epoch, tokenizer, num_data_epochs):
-        # TODO Add an option to memmap the training data if needed (see note in pregenerate_training_data)
+    def __init__(self, training_path, epoch, tokenizer, num_data_epochs, reduce_memory=False):
         self.vocab = tokenizer.vocab
         self.tokenizer = tokenizer
         self.epoch = epoch
@@ -65,6 +65,23 @@ class PregeneratedDataset(Dataset):
         metrics = json.loads(metrics_file.read_text())
         num_samples = metrics['num_training_examples']
         seq_len = metrics['max_seq_len']
+        self.temp_dir = None
+        self.working_dir = None
+        if reduce_memory:
+            self.temp_dir = TemporaryDirectory()
+            self.working_dir = Path(self.temp_dir.name)
+            input_ids = np.memmap(filename=self.working_dir/'input_ids.memmap',
+                                  mode='w+', dtype=np.int32, shape=(num_samples, seq_len))
+            input_masks = np.memmap(filename=self.working_dir/'input_masks.memmap',
+                                    shape=(num_samples, seq_len), mode='w+', dtype=np.bool)
+            segment_ids = np.memmap(filename=self.working_dir/'segment_ids.memmap',
+                                    shape=(num_samples, seq_len), mode='w+', dtype=np.bool)
+            lm_label_ids = np.memmap(filename=self.working_dir/'lm_label_ids.memmap',
+                                     shape=(num_samples, seq_len), mode='w+', dtype=np.int32)
+            lm_label_ids[:] = -1
+            is_nexts = np.memmap(filename=self.working_dir/'is_nexts.memmap',
+                                 shape=(num_samples,), mode='w+', dtype=np.bool)
+        else:
             input_ids = np.zeros(shape=(num_samples, seq_len), dtype=np.int32)
             input_masks = np.zeros(shape=(num_samples, seq_len), dtype=np.bool)
             segment_ids = np.zeros(shape=(num_samples, seq_len), dtype=np.bool)
@@ -110,6 +127,8 @@ def main():
                         choices=["bert-base-uncased", "bert-large-uncased", "bert-base-cased",
                                  "bert-base-multilingual", "bert-base-chinese"])
     parser.add_argument("--do_lower_case", action="store_true")
+    parser.add_argument("--reduce_memory", action="store_true",
+                        help="Store training data as on-disc memmaps to massively reduce memory usage")
     parser.add_argument("--epochs", type=int, default=3, help="Number of epochs to train for")
     parser.add_argument("--local_rank",
@@ -311,7 +330,5 @@ def main():
     torch.save(model_to_save.state_dict(), str(output_model_file))
 if __name__ == '__main__':
     main()
@@ -11,15 +11,10 @@ import json
 class DocumentDatabase:
-    def __init__(self, reduce_memory=False, working_dir=None):
+    def __init__(self, reduce_memory=False):
         if reduce_memory:
-            if working_dir is None:
             self.temp_dir = TemporaryDirectory()
             self.working_dir = Path(self.temp_dir.name)
-            else:
-                self.temp_dir = None
-                self.working_dir = Path(working_dir)
-                self.working_dir.mkdir(parents=True, exist_ok=True)
             self.document_shelf_filepath = self.working_dir / 'shelf.db'
             self.document_shelf = shelve.open(str(self.document_shelf_filepath),
                                               flag='n', protocol=-1)
@@ -237,8 +232,6 @@ def main():
     parser.add_argument("--reduce_memory", action="store_true",
                         help="Reduce memory usage for large datasets by keeping data on disc rather than in memory")
-    parser.add_argument("--working_dir", type=Path, default=None,
-                        help="Temporary directory to use for --reduce_memory. If not set, uses TemporaryDirectory()")
     parser.add_argument("--epochs_to_generate", type=int, default=3,
                         help="Number of epochs of data to pregenerate")
@@ -254,7 +247,7 @@ def main():
     tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
     vocab_list = list(tokenizer.vocab.keys())
-    docs = DocumentDatabase(reduce_memory=args.reduce_memory, working_dir=args.working_dir)
+    docs = DocumentDatabase(reduce_memory=args.reduce_memory)
     with args.train_corpus.open() as f:
         doc = []
         for line in tqdm(f, desc="Loading Dataset", unit=" lines"):
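The second file's `DocumentDatabase` relies on the same keep-it-on-disc idea, but via `shelve`: with `--reduce_memory`, documents are pickled into a shelf database inside a `TemporaryDirectory` (the commit drops the user-supplied `--working_dir` variant) and looked up by index. A rough stand-in sketch follows, using a hypothetical `DiskBackedDocs` class rather than the script's actual implementation:

```python
# Hedged sketch of the DocumentDatabase pattern touched by this diff: documents
# live in an on-disk shelve DB inside a TemporaryDirectory, keyed by their
# integer index, so the full corpus never sits in RAM. DiskBackedDocs is a
# made-up name for illustration, not the real class.
import shelve
from pathlib import Path
from tempfile import TemporaryDirectory


class DiskBackedDocs:
    def __init__(self):
        self.temp_dir = TemporaryDirectory()
        self.working_dir = Path(self.temp_dir.name)
        # flag='n' always creates a fresh DB; protocol=-1 uses the newest pickle protocol
        self.shelf = shelve.open(str(self.working_dir / 'shelf.db'),
                                 flag='n', protocol=-1)
        self.n_docs = 0

    def add_document(self, document):
        # shelve keys must be strings, so the running index is stringified
        self.shelf[str(self.n_docs)] = document
        self.n_docs += 1

    def __getitem__(self, item):
        return self.shelf[str(item)]


docs = DiskBackedDocs()
docs.add_document(["a tokenized sentence", "another sentence"])
print(docs[0])
```

Always creating the shelf inside a `TemporaryDirectory` (rather than a caller-chosen directory) is what lets the commit delete the `--working_dir` plumbing shown above.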