Commit c53cc018 authored by Julien Chaumond's avatar Julien Chaumond
Browse files

[Trainer] Fix _rotate_checkpoints

Close #3920
parent cbbb3c43
......@@ -434,13 +434,13 @@ class Trainer:
def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]:
ordering_and_checkpoint_path = []
glob_checkpoints = Path(self.args.output_dir).glob(f"{checkpoint_prefix}-*")
glob_checkpoints = [str(x) for x in Path(self.args.output_dir).glob(f"{checkpoint_prefix}-*")]
for path in glob_checkpoints:
if use_mtime:
ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
else:
regex_match = re.match(".*{}-([0-9]+)".format(checkpoint_prefix), path)
regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
if regex_match and regex_match.groups():
ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
......@@ -449,9 +449,7 @@ class Trainer:
return checkpoints_sorted
def _rotate_checkpoints(self, use_mtime=False) -> None:
if not self.args.save_total_limit:
return
if self.args.save_total_limit <= 0:
if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
return
# Check if we should delete older checkpoint(s)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment