"vscode:/vscode.git/clone" did not exist on "70f429a419dc0bab700b576beaf1f6f0b02da40a"
Commit d6c54697 authored by jinoobaek-qz, committed by Lysandre Debut

Delete older checkpoint after saving new checkpoint

parent 54a31f50
@@ -224,10 +224,19 @@ def train(args, train_dataset, model, tokenizer):
                     logging_loss = tr_loss
 
                 if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
+                    # Save model checkpoint
+                    output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
+                    if not os.path.exists(output_dir):
+                        os.makedirs(output_dir)
+                    model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
+                    model_to_save.save_pretrained(output_dir)
+                    torch.save(args, os.path.join(output_dir, 'training_args.bin'))
+                    logger.info("Saving model checkpoint to %s", output_dir)
+
                     if args.save_total_limit and args.save_total_limit > 0:
                         # Check if we should delete older checkpoint(s)
                         glob_checkpoints = glob.glob(os.path.join(args.output_dir, 'checkpoint-*'))
-                        if len(glob_checkpoints) + 1 > args.save_total_limit:
+                        if len(glob_checkpoints) > args.save_total_limit:
                             checkpoints_sorted = []
                             for path in glob_checkpoints:
                                 regex_match = re.match('.*checkpoint-([0-9]+)', path)
@@ -236,21 +245,12 @@ def train(args, train_dataset, model, tokenizer):
                             checkpoints_sorted = sorted(checkpoints_sorted)
                             checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
-                            number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) + 1 - args.save_total_limit)
+                            number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - args.save_total_limit)
                             checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
                             for checkpoint in checkpoints_to_be_deleted:
                                 logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint))
                                 shutil.rmtree(checkpoint)
 
-                    # Save model checkpoint
-                    output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
-                    if not os.path.exists(output_dir):
-                        os.makedirs(output_dir)
-                    model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
-                    model_to_save.save_pretrained(output_dir)
-                    torch.save(args, os.path.join(output_dir, 'training_args.bin'))
-                    logger.info("Saving model checkpoint to %s", output_dir)
-
             if args.max_steps > 0 and global_step > args.max_steps:
                 epoch_iterator.close()
                 break
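For reference, the reordered logic amounts to the following standalone sketch: save the new checkpoint first, then prune the oldest checkpoint-<step> directories until at most save_total_limit remain. The _rotate_checkpoints helper and its signature are hypothetical (the committed code runs inline inside train()); the globbing, numeric sort on the step number, and shutil.rmtree calls mirror the diff above.

import glob
import os
import re
import shutil


def _rotate_checkpoints(output_dir, save_total_limit):
    # Hypothetical helper: the committed change keeps this logic inline in train().
    # Keeps only the `save_total_limit` most recent checkpoint-<step> directories.
    if not save_total_limit or save_total_limit <= 0:
        return
    # Pair each checkpoint path with its global step so sorting is numeric, not lexicographic.
    steps_and_paths = []
    for path in glob.glob(os.path.join(output_dir, 'checkpoint-*')):
        regex_match = re.match('.*checkpoint-([0-9]+)', path)
        if regex_match and regex_match.groups():
            steps_and_paths.append((int(regex_match.groups()[0]), path))
    checkpoints_sorted = [path for _, path in sorted(steps_and_paths)]
    # The newest checkpoint is already on disk when this runs, so the raw count is
    # compared against the limit (no "+ 1" as in the pre-commit ordering).
    number_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
    for checkpoint in checkpoints_sorted[:number_to_delete]:
        shutil.rmtree(checkpoint)

Saving before pruning means the count seen by the pruning code already includes the newest checkpoint, which is why both "+ 1" adjustments disappear in the diff.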