Commit 528d3f32 authored by jinoobaek-qz, committed by Lysandre Debut

Improve readability and make fewer assumptions about checkpoint format

parent 56301bd9
@@ -106,15 +106,22 @@ def set_seed(args):
         torch.cuda.manual_seed_all(args.seed)
 
 
-def rotate_checkpoints(args):
-    if args.save_total_limit and args.save_total_limit > 0:
-        # Check if we should delete older checkpoint(s)
-        glob_checkpoints = glob.glob(os.path.join(args.output_dir, 'checkpoint-*'))
-        if len(glob_checkpoints) > args.save_total_limit:
-            checkpoints_sorted = []
-            for path in glob_checkpoints:
-                regex_match = re.match('.*checkpoint-([0-9]+)', path)
-                if regex_match and regex_match.groups():
-                    checkpoints_sorted.append((int(regex_match.groups()[0]), path))
-            checkpoints_sorted = sorted(checkpoints_sorted)
+def _rotate_checkpoints(args, checkpoint_prefix, use_mtime=False):
+    if not args.save_total_limit:
+        return
+    if args.save_total_limit <= 0:
+        return
+
+    # Check if we should delete older checkpoint(s)
+    glob_checkpoints = glob.glob(os.path.join(args.output_dir, '{}-*'.format(checkpoint_prefix)))
+    if len(glob_checkpoints) > args.save_total_limit:
+        checkpoints_sorted = []
+        for path in glob_checkpoints:
+            regex_match = re.match('.*{}-([0-9]+)'.format(checkpoint_prefix), path)
+            if regex_match and regex_match.groups():
+                if use_mtime:
+                    checkpoints_sorted.append((os.path.getmtime(path), path))
+                else:
+                    checkpoints_sorted.append((int(regex_match.groups()[0]), path))
+        checkpoints_sorted = sorted(checkpoints_sorted)
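For context, the reworked helper orders checkpoints either by the step number embedded in the directory name or, when use_mtime=True, by filesystem modification time. Below is a minimal standalone sketch of that ordering logic; the function name sort_checkpoints and the sample paths are illustrative only and not part of this commit.

import os
import re

def sort_checkpoints(paths, checkpoint_prefix='checkpoint', use_mtime=False):
    # Pair each path with a sort key: the parsed step number, or the file's
    # mtime when use_mtime is set, mirroring the ordering in the diff above.
    keyed = []
    for path in paths:
        if use_mtime:
            keyed.append((os.path.getmtime(path), path))
        else:
            regex_match = re.match(r'.*{}-([0-9]+)'.format(checkpoint_prefix), path)
            if regex_match and regex_match.groups():
                keyed.append((int(regex_match.groups()[0]), path))
    return [path for _, path in sorted(keyed)]

# Step-number ordering is independent of the (arbitrary) glob order:
print(sort_checkpoints(['out/checkpoint-1000', 'out/checkpoint-500', 'out/checkpoint-1500']))
# -> ['out/checkpoint-500', 'out/checkpoint-1000', 'out/checkpoint-1500']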
@@ -244,8 +251,9 @@ def train(args, train_dataset, model, tokenizer):
                     logging_loss = tr_loss
 
                 if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
+                    checkpoint_prefix = 'checkpoint'
                     # Save model checkpoint
-                    output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
+                    output_dir = os.path.join(args.output_dir, '{}-{}'.format(checkpoint_prefix, global_step))
                     if not os.path.exists(output_dir):
                         os.makedirs(output_dir)
                     model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
@@ -253,7 +261,7 @@ def train(args, train_dataset, model, tokenizer):
                     torch.save(args, os.path.join(output_dir, 'training_args.bin'))
                     logger.info("Saving model checkpoint to %s", output_dir)
 
-                    rotate_checkpoints(args)
+                    _rotate_checkpoints(args, checkpoint_prefix)
 
             if args.max_steps > 0 and global_step > args.max_steps:
                 epoch_iterator.close()
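The hunks above end before the part of _rotate_checkpoints that actually removes surplus directories, so the pruning step is not visible here. As a rough sketch only, assuming the oldest checkpoints are dropped with shutil.rmtree until save_total_limit remain (the name prune_checkpoints and this deletion strategy are assumptions, not taken from this commit):

import shutil

def prune_checkpoints(checkpoints_sorted, save_total_limit):
    # Illustrative only: checkpoints_sorted is a list of (key, path) tuples
    # ordered oldest-first, as built by the helper above. Delete directories
    # until at most save_total_limit remain.
    number_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
    for _, path in checkpoints_sorted[:number_to_delete]:
        shutil.rmtree(path)  # each checkpoint is a whole directory of files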