Commit b8ff5689 authored by wangfei's avatar wangfei
Browse files

Fix bug of multi-gpu training in lm finetuning

parent 9d0029e2
...@@ -320,7 +320,7 @@ def main(): ...@@ -320,7 +320,7 @@ def main():
global_step += 1 global_step += 1
# Save a trained model # Save a trained model
if n_gpu > 1 and torch.distributed.get_rank() == 0 or n_gpu <=1 : if args.local_rank == -1 or torch.distributed.get_rank() == 0:
logging.info("** ** * Saving fine-tuned model ** ** * ") logging.info("** ** * Saving fine-tuned model ** ** * ")
model.save_pretrained(args.output_dir) model.save_pretrained(args.output_dir)
tokenizer.save_pretrained(args.output_dir) tokenizer.save_pretrained(args.output_dir)
......
...@@ -507,7 +507,7 @@ def main(): ...@@ -507,7 +507,7 @@ def main():
if os.path.exists(args.output_dir) and os.listdir(args.output_dir): if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
if not os.path.exists(args.output_dir) and ( n_gpu > 1 and torch.distributed.get_rank() == 0 or n_gpu <=1 ): if not os.path.exists(args.output_dir) and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
os.makedirs(args.output_dir) os.makedirs(args.output_dir)
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
...@@ -608,7 +608,7 @@ def main(): ...@@ -608,7 +608,7 @@ def main():
global_step += 1 global_step += 1
# Save a trained model # Save a trained model
if args.do_train and ( n_gpu > 1 and torch.distributed.get_rank() == 0 or n_gpu <=1): if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
logger.info("** ** * Saving fine - tuned model ** ** * ") logger.info("** ** * Saving fine - tuned model ** ** * ")
model.save_pretrained(args.output_dir) model.save_pretrained(args.output_dir)
tokenizer.save_pretrained(args.output_dir) tokenizer.save_pretrained(args.output_dir)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment