Commit 08ff056c authored by Mayhul Arora

Added option to use multiple workers to create training data for LM fine-tuning

parent 98dc30b2
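In outline, the change farms each epoch's file out to its own process with `multiprocessing.Pool.starmap`, with the pool size capped at the number of epochs. A minimal, self-contained sketch of that dispatch pattern (the `write_epoch` helper and the `out` directory are hypothetical stand-ins for the script's per-epoch work, not its actual names):

```python
# Sketch of the dispatch pattern this commit adds: one worker process per
# epoch, capped at the number of epochs so no worker sits idle.
from multiprocessing import Pool


def write_epoch(output_dir, epoch_num):
    # Hypothetical stand-in for writing epoch_{n}.json and its metrics file.
    print("would write {}/epoch_{}.json".format(output_dir, epoch_num))


if __name__ == '__main__':
    num_workers = 2
    epochs_to_generate = 3
    arguments = [("out", idx) for idx in range(epochs_to_generate)]
    if num_workers > 1:
        with Pool(min(num_workers, epochs_to_generate)) as pool:
            # starmap unpacks each argument tuple into write_epoch's parameters.
            pool.starmap(write_epoch, arguments)
    else:
        for output_dir, epoch_num in arguments:
            write_epoch(output_dir, epoch_num)
```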
@@ -3,6 +3,7 @@ from pathlib import Path
 from tqdm import tqdm, trange
 from tempfile import TemporaryDirectory
 import shelve
+from multiprocessing import Pool
 from random import random, randrange, randint, shuffle, choice
 from pytorch_pretrained_bert.tokenization import BertTokenizer
@@ -264,6 +265,28 @@ def create_instances_from_document(
     return instances
 
 
+def create_training_file(docs, vocab_list, args, epoch_num):
+    epoch_filename = args.output_dir / "epoch_{}.json".format(epoch_num)
+    num_instances = 0
+    with epoch_filename.open('w') as epoch_file:
+        for doc_idx in trange(len(docs), desc="Document"):
+            doc_instances = create_instances_from_document(
+                docs, doc_idx, max_seq_length=args.max_seq_len, short_seq_prob=args.short_seq_prob,
+                masked_lm_prob=args.masked_lm_prob, max_predictions_per_seq=args.max_predictions_per_seq,
+                whole_word_mask=args.do_whole_word_mask, vocab_list=vocab_list)
+            doc_instances = [json.dumps(instance) for instance in doc_instances]
+            for instance in doc_instances:
+                epoch_file.write(instance + '\n')
+                num_instances += 1
+    metrics_file = args.output_dir / "epoch_{}_metrics.json".format(epoch_num)
+    with metrics_file.open('w') as metrics_file:
+        metrics = {
+            "num_training_examples": num_instances,
+            "max_seq_len": args.max_seq_len
+        }
+        metrics_file.write(json.dumps(metrics))
+
+
 def main():
     parser = ArgumentParser()
     parser.add_argument('--train_corpus', type=Path, required=True)
@@ -277,6 +300,8 @@ def main():
     parser.add_argument("--reduce_memory", action="store_true",
                         help="Reduce memory usage for large datasets by keeping data on disc rather than in memory")
+    parser.add_argument("--num_workers", type=int, default=1,
+                        help="The number of workers to use to write the files")
     parser.add_argument("--epochs_to_generate", type=int, default=3,
                         help="Number of epochs of data to pregenerate")
     parser.add_argument("--max_seq_len", type=int, default=128)
@@ -289,6 +314,9 @@ def main():
     args = parser.parse_args()
 
+    if args.num_workers > 1 and args.reduce_memory:
+        raise ValueError("Cannot use multiple workers while reducing memory")
+
     tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
     vocab_list = list(tokenizer.vocab.keys())
     with DocumentDatabase(reduce_memory=args.reduce_memory) as docs:
@@ -312,26 +340,14 @@ def main():
                  "sections or paragraphs.")
 
         args.output_dir.mkdir(exist_ok=True)
-        for epoch in trange(args.epochs_to_generate, desc="Epoch"):
-            epoch_filename = args.output_dir / f"epoch_{epoch}.json"
-            num_instances = 0
-            with epoch_filename.open('w') as epoch_file:
-                for doc_idx in trange(len(docs), desc="Document"):
-                    doc_instances = create_instances_from_document(
-                        docs, doc_idx, max_seq_length=args.max_seq_len, short_seq_prob=args.short_seq_prob,
-                        masked_lm_prob=args.masked_lm_prob, max_predictions_per_seq=args.max_predictions_per_seq,
-                        whole_word_mask=args.do_whole_word_mask, vocab_list=vocab_list)
-                    doc_instances = [json.dumps(instance) for instance in doc_instances]
-                    for instance in doc_instances:
-                        epoch_file.write(instance + '\n')
-                        num_instances += 1
-            metrics_file = args.output_dir / f"epoch_{epoch}_metrics.json"
-            with metrics_file.open('w') as metrics_file:
-                metrics = {
-                    "num_training_examples": num_instances,
-                    "max_seq_len": args.max_seq_len
-                }
-                metrics_file.write(json.dumps(metrics))
+
+        if args.num_workers > 1:
+            writer_workers = Pool(min(args.num_workers, args.epochs_to_generate))
+            arguments = [(docs, vocab_list, args, idx) for idx in range(args.epochs_to_generate)]
+            writer_workers.starmap(create_training_file, arguments)
+        else:
+            for epoch in trange(args.epochs_to_generate, desc="Epoch"):
+                create_training_file(docs, vocab_list, args, epoch)
 
 
 if __name__ == '__main__':
...
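A note on the `--num_workers`/`--reduce_memory` guard: one plausible reading (an assumption, not stated in the commit) is that the reduce-memory path keeps documents in a `shelve` store, and a handle to an on-disk database generally cannot be pickled, which is what `Pool` must do to ship arguments to worker processes. A quick self-contained probe of that assumption (the result depends on the local `dbm` backend):

```python
# Probe whether a shelve handle survives pickling, which is what
# multiprocessing.Pool would need to do to pass it to a worker.
import os
import pickle
import shelve
import tempfile

with tempfile.TemporaryDirectory() as tmpdir:
    db = shelve.open(os.path.join(tmpdir, "document_shelf"))
    db["0"] = ["some document"]
    try:
        pickle.dumps(db)
    except (TypeError, pickle.PicklingError) as exc:
        # Typical outcome with the common gdbm/ndbm backends.
        print("shelve handle is not picklable:", exc)
    else:
        print("picklable with this dbm backend")
    finally:
        db.close()
```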