Commit 1ef41b83 authored by wangfei

Revert "Fix: save model/model.module"

This reverts commit 00e9c4cc.
parent 00e9c4cc
@@ -155,12 +155,12 @@ def main():
                         help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                              "0 (default value): dynamic loss scaling.\n"
                              "Positive power of 2: static loss scaling value.\n")
     parser.add_argument("--warmup_steps",
                         default=0,
                         type=int,
                         help="Linear warmup over warmup_steps.")
     parser.add_argument("--adam_epsilon",
                         default=1e-8,
                         type=float,
                         help="Epsilon for Adam optimizer.")
     parser.add_argument("--learning_rate",
@@ -322,8 +322,7 @@ def main():
     # Save a trained model
     if args.local_rank == -1 or torch.distributed.get_rank() == 0:
         logging.info("** ** * Saving fine-tuned model ** ** * ")
-        model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
-        model_to_save.save_pretrained(args.output_dir)
+        model.save_pretrained(args.output_dir)
         tokenizer.save_pretrained(args.output_dir)
@@ -610,8 +610,7 @@ def main():
     # Save a trained model
     if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
         logger.info("** ** * Saving fine - tuned model ** ** * ")
-        model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
-        model_to_save.save_pretrained(args.output_dir)
+        model.save_pretrained(args.output_dir)
         tokenizer.save_pretrained(args.output_dir)
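For context on the lines being reverted: when a model is wrapped in torch.nn.DataParallel or torch.nn.parallel.DistributedDataParallel, the underlying Hugging Face model is exposed as the wrapper's .module attribute, and save_pretrained() is defined only on that underlying model, not on the wrapper. The removed line, model_to_save = model.module if hasattr(model, 'module') else model, unwrapped the model before saving. Below is a minimal, self-contained sketch of that pattern, assuming the transformers package and using an illustrative toy BertConfig model and output directory rather than anything from this repository.

import torch
from transformers import BertConfig, BertForSequenceClassification

# Illustrative toy model; a real run would load and fine-tune an actual checkpoint.
config = BertConfig(hidden_size=32, num_hidden_layers=1, num_attention_heads=2,
                    intermediate_size=64)
model = BertForSequenceClassification(config)

if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model)  # the real model now lives at model.module

# The guard removed by this revert: unwrap before saving, because calling
# save_pretrained() on the DataParallel/DDP wrapper raises AttributeError.
model_to_save = model.module if hasattr(model, "module") else model
model_to_save.save_pretrained("./fine_tuned_model")  # hypothetical output directory

After the revert, model.save_pretrained(args.output_dir) is called on the possibly wrapped object directly, which works in single-GPU runs but fails if the model has been wrapped for multi-GPU or distributed training.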
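The first hunk only shows context around the --warmup_steps, --adam_epsilon, and --learning_rate arguments. As a hedged, self-contained sketch of how such arguments are typically consumed in fine-tuning scripts of this kind (the dummy model, step counts, and the transformers scheduler helper below are assumptions, not code from this repository):

import torch
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

learning_rate = 5e-5   # --learning_rate
adam_epsilon = 1e-8    # --adam_epsilon
warmup_steps = 100     # --warmup_steps
total_steps = 1000     # in a real script: len(train_dataloader) * num_epochs

model = torch.nn.Linear(10, 2)  # stand-in for the fine-tuned model
optimizer = AdamW(model.parameters(), lr=learning_rate, eps=adam_epsilon)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)

for step in range(total_steps):
    loss = model(torch.randn(8, 10)).sum()  # placeholder forward pass and loss
    loss.backward()
    optimizer.step()
    scheduler.step()  # LR rises linearly for warmup_steps, then decays linearly to zero
    optimizer.zero_grad()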