"git@developer.sourcefind.cn:hehl2/torchaudio.git" did not exist on "f6dc2f67783082b433dfa99d4b0a8992ba64be9d"
Commit abb7d1ff authored by Matthew Carrigan

Added proper context management to ensure cleanup happens in the right order.
parent 06a30cfd
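In short, the change below turns DocumentDatabase from an object with a manual cleanup() method into a context manager, so the document shelf is closed before the temporary directory that backs it is removed, and the cleanup runs even if an exception interrupts main(). A minimal sketch of the same pattern, assuming a shelve file kept inside a tempfile.TemporaryDirectory (the ShelfBackedStore class and its attributes are illustrative only, not the repository's code):

```python
import shelve
import tempfile
from pathlib import Path


class ShelfBackedStore:
    """Illustrative stand-in for a shelf stored inside a temporary directory."""

    def __init__(self):
        self.temp_dir = tempfile.TemporaryDirectory()
        self.shelf = shelve.open(str(Path(self.temp_dir.name) / "shelf.db"))

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, traceback):
        # Close the shelf first, then remove the directory that holds its
        # files; this runs even if the body of the `with` block raises.
        self.shelf.close()
        self.temp_dir.cleanup()


with ShelfBackedStore() as store:
    store.shelf["doc_0"] = ["some", "tokens"]
# At this point the shelf has been closed and the temp directory deleted.
```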
@@ -23,6 +23,7 @@ class DocumentDatabase:
         self.documents = []
         self.document_shelf = None
         self.document_shelf_filepath = None
+        self.temp_dir = None
         self.doc_lengths = []
         self.doc_cumsum = None
         self.cumsum_max = None
@@ -68,9 +69,14 @@ class DocumentDatabase:
         else:
             return self.documents[item]

-    def cleanup(self):
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, traceback):
         if self.document_shelf is not None:
             self.document_shelf.close()
+        if self.temp_dir is not None:
+            self.temp_dir.cleanup()


 def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens):
@@ -247,40 +253,39 @@ def main():
     tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
     vocab_list = list(tokenizer.vocab.keys())

-    docs = DocumentDatabase(reduce_memory=args.reduce_memory)
-    with args.train_corpus.open() as f:
-        doc = []
-        for line in tqdm(f, desc="Loading Dataset", unit=" lines"):
-            line = line.strip()
-            if line == "":
-                docs.add_document(doc)
-                doc = []
-            else:
-                tokens = tokenizer.tokenize(line)
-                doc.append(tokens)
-
-    args.output_dir.mkdir(exist_ok=True)
-    for epoch in trange(args.epochs_to_generate, desc="Epoch"):
-        epoch_filename = args.output_dir / f"epoch_{epoch}.json"
-        num_instances = 0
-        with epoch_filename.open('w') as epoch_file:
-            for doc_idx in trange(len(docs), desc="Document"):
-                doc_instances = create_instances_from_document(
-                    docs, doc_idx, max_seq_length=args.max_seq_len, short_seq_prob=args.short_seq_prob,
-                    masked_lm_prob=args.masked_lm_prob, max_predictions_per_seq=args.max_predictions_per_seq,
-                    vocab_list=vocab_list)
-                doc_instances = [json.dumps(instance) for instance in doc_instances]
-                for instance in doc_instances:
-                    epoch_file.write(instance + '\n')
-                    num_instances += 1
-        metrics_file = args.output_dir / f"epoch_{epoch}_metrics.json"
-        with metrics_file.open('w') as metrics_file:
-            metrics = {
-                "num_training_examples": num_instances,
-                "max_seq_len": args.max_seq_len
-            }
-            metrics_file.write(json.dumps(metrics))
-    docs.cleanup()
+    with DocumentDatabase(reduce_memory=args.reduce_memory) as docs:
+        with args.train_corpus.open() as f:
+            doc = []
+            for line in tqdm(f, desc="Loading Dataset", unit=" lines"):
+                line = line.strip()
+                if line == "":
+                    docs.add_document(doc)
+                    doc = []
+                else:
+                    tokens = tokenizer.tokenize(line)
+                    doc.append(tokens)
+
+        args.output_dir.mkdir(exist_ok=True)
+        for epoch in trange(args.epochs_to_generate, desc="Epoch"):
+            epoch_filename = args.output_dir / f"epoch_{epoch}.json"
+            num_instances = 0
+            with epoch_filename.open('w') as epoch_file:
+                for doc_idx in trange(len(docs), desc="Document"):
+                    doc_instances = create_instances_from_document(
+                        docs, doc_idx, max_seq_length=args.max_seq_len, short_seq_prob=args.short_seq_prob,
+                        masked_lm_prob=args.masked_lm_prob, max_predictions_per_seq=args.max_predictions_per_seq,
+                        vocab_list=vocab_list)
+                    doc_instances = [json.dumps(instance) for instance in doc_instances]
+                    for instance in doc_instances:
+                        epoch_file.write(instance + '\n')
+                        num_instances += 1
+            metrics_file = args.output_dir / f"epoch_{epoch}_metrics.json"
+            with metrics_file.open('w') as metrics_file:
+                metrics = {
+                    "num_training_examples": num_instances,
+                    "max_seq_len": args.max_seq_len
+                }
+                metrics_file.write(json.dumps(metrics))


 if __name__ == '__main__':