Commit 54a31f50 authored by jinoobaek-qz, committed by Lysandre Debut

Add save_total_limit

parent 1c507995
@@ -27,6 +27,8 @@ import logging
 import os
 import pickle
 import random
+import re
+import shutil
 
 import numpy as np
 import torch
@@ -222,6 +224,24 @@ def train(args, train_dataset, model, tokenizer):
                     logging_loss = tr_loss
 
                 if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
+                    if args.save_total_limit and args.save_total_limit > 0:
+                        # Check if we should delete older checkpoint(s)
+                        glob_checkpoints = glob.glob(os.path.join(args.output_dir, 'checkpoint-*'))
+                        if len(glob_checkpoints) + 1 > args.save_total_limit:
+                            checkpoints_sorted = []
+                            for path in glob_checkpoints:
+                                regex_match = re.match('.*checkpoint-([0-9]+)', path)
+                                if regex_match and regex_match.groups():
+                                    checkpoints_sorted.append((int(regex_match.groups()[0]), path))
+
+                            checkpoints_sorted = sorted(checkpoints_sorted)
+                            checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
+                            number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) + 1 - args.save_total_limit)
+                            checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
+                            for checkpoint in checkpoints_to_be_deleted:
+                                logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint))
+                                shutil.rmtree(checkpoint)
+
                     # Save model checkpoint
                     output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
                     if not os.path.exists(output_dir):
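The deletion logic added above can also be read as a standalone helper. The sketch below is behaviorally equivalent but is not part of this commit; the name `rotate_checkpoints` and its signature are hypothetical.

```python
import glob
import os
import re
import shutil


def rotate_checkpoints(output_dir, save_total_limit):
    """Delete the oldest checkpoint-* directories so that, after the next
    checkpoint is written, at most save_total_limit of them remain."""
    if not save_total_limit or save_total_limit <= 0:
        return
    # Collect (global_step, path) pairs for every existing checkpoint directory.
    checkpoints = []
    for path in glob.glob(os.path.join(output_dir, 'checkpoint-*')):
        match = re.match(r'.*checkpoint-([0-9]+)', path)
        if match:
            checkpoints.append((int(match.group(1)), path))
    # Sort by step number so the oldest checkpoints come first.
    checkpoints.sort()
    # The "+ 1" accounts for the checkpoint that is about to be saved.
    number_to_delete = max(0, len(checkpoints) + 1 - save_total_limit)
    for _, path in checkpoints[:number_to_delete]:
        shutil.rmtree(path)
```

As in the diff, the `+ 1` means the limit is enforced before the new checkpoint for the current `global_step` is written, so the directory never temporarily exceeds the limit.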
@@ -359,6 +379,8 @@ def main():
                         help="Log every X updates steps.")
     parser.add_argument('--save_steps', type=int, default=50,
                         help="Save checkpoint every X updates steps.")
+    parser.add_argument('--save_total_limit', type=int, default=None,
+                        help='Limit the total number of checkpoints. Deletes the older checkpoints in output_dir. Does not delete by default.')
     parser.add_argument("--eval_all_checkpoints", action='store_true',
                         help="Evaluate all checkpoints starting with the same prefix as model_name_or_path and ending with step number")
     parser.add_argument("--no_cuda", action='store_true',
...
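Because `--save_total_limit` defaults to `None` and the training loop guards on `args.save_total_limit and args.save_total_limit > 0`, omitting the flag (or passing `0`) leaves every checkpoint in place; passing, say, `--save_total_limit 2` keeps only the two most recent `checkpoint-*` directories. A minimal sketch of that interaction, using a throwaway parser rather than the script's own:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--save_total_limit', type=int, default=None)

# Flag omitted: save_total_limit is None, so the rotation branch is skipped.
args = parser.parse_args([])
assert not (args.save_total_limit and args.save_total_limit > 0)

# Flag set to 2: the rotation branch runs and keeps the 2 newest checkpoints.
args = parser.parse_args(['--save_total_limit', '2'])
assert args.save_total_limit and args.save_total_limit > 0
```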