Commit abb7d1ff authored by Matthew Carrigan

Added proper context management to ensure cleanup happens in the right order.
parent 06a30cfd
@@ -23,6 +23,7 @@ class DocumentDatabase:
        self.documents = []
        self.document_shelf = None
        self.document_shelf_filepath = None
+       self.temp_dir = None
        self.doc_lengths = []
        self.doc_cumsum = None
        self.cumsum_max = None
@@ -68,9 +69,14 @@ class DocumentDatabase:
        else:
            return self.documents[item]

-   def cleanup(self):
+   def __enter__(self):
+       return self
+
+   def __exit__(self, exc_type, exc_val, traceback):
        if self.document_shelf is not None:
            self.document_shelf.close()
+       if self.temp_dir is not None:
+           self.temp_dir.cleanup()


def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens):
@@ -247,40 +253,39 @@ def main():
    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
    vocab_list = list(tokenizer.vocab.keys())
-   docs = DocumentDatabase(reduce_memory=args.reduce_memory)
+   with DocumentDatabase(reduce_memory=args.reduce_memory) as docs:
        with args.train_corpus.open() as f:
            doc = []
            for line in tqdm(f, desc="Loading Dataset", unit=" lines"):
                line = line.strip()
                if line == "":
                    docs.add_document(doc)
                    doc = []
                else:
                    tokens = tokenizer.tokenize(line)
                    doc.append(tokens)

        args.output_dir.mkdir(exist_ok=True)
        for epoch in trange(args.epochs_to_generate, desc="Epoch"):
            epoch_filename = args.output_dir / f"epoch_{epoch}.json"
            num_instances = 0
            with epoch_filename.open('w') as epoch_file:
                for doc_idx in trange(len(docs), desc="Document"):
                    doc_instances = create_instances_from_document(
                        docs, doc_idx, max_seq_length=args.max_seq_len, short_seq_prob=args.short_seq_prob,
                        masked_lm_prob=args.masked_lm_prob, max_predictions_per_seq=args.max_predictions_per_seq,
                        vocab_list=vocab_list)
                    doc_instances = [json.dumps(instance) for instance in doc_instances]
                    for instance in doc_instances:
                        epoch_file.write(instance + '\n')
                        num_instances += 1
            metrics_file = args.output_dir / f"epoch_{epoch}_metrics.json"
            with metrics_file.open('w') as metrics_file:
                metrics = {
                    "num_training_examples": num_instances,
                    "max_seq_len": args.max_seq_len
                }
                metrics_file.write(json.dumps(metrics))
-   docs.cleanup()


if __name__ == '__main__':
...
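For context, here is a minimal sketch of the pattern this commit adopts, not code from the repository itself: __exit__ runs when the with block ends, even if the body raises, and it closes the shelf before deleting the temporary directory that backs it. The ShelfBackedStore name and the shelve/TemporaryDirectory pairing are assumptions suggested by the attribute names in the diff (document_shelf, temp_dir).

import shelve
from pathlib import Path
from tempfile import TemporaryDirectory

class ShelfBackedStore:
    def __init__(self):
        # Hypothetical stand-in for DocumentDatabase(reduce_memory=True):
        # a shelf file living inside a temporary directory.
        self.temp_dir = TemporaryDirectory()
        shelf_path = Path(self.temp_dir.name) / 'shelf.db'
        self.document_shelf = shelve.open(str(shelf_path), flag='n', protocol=-1)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, traceback):
        # Order matters: flush and release the shelf's files first,
        # then remove the directory that contains them.
        self.document_shelf.close()
        self.temp_dir.cleanup()

with ShelfBackedStore() as store:
    store.document_shelf['0'] = ['hello', 'world']
# Both cleanup steps have run by this point, even if the body above raised,
# which the old explicit docs.cleanup() call could not guarantee.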