"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "ec07da65e25562040581febaf9b400a462962961"
Unverified Commit a515caa3 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Fix checkpoint deletion (#11748)

parent b88e0e01
...@@ -1523,10 +1523,6 @@ class Trainer: ...@@ -1523,10 +1523,6 @@ class Trainer:
if self.is_world_process_zero(): if self.is_world_process_zero():
self.state.save_to_json(os.path.join(output_dir, "trainer_state.json")) self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
# Maybe delete some older checkpoints.
if self.is_world_process_zero():
self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
# Save RNG state in non-distributed training # Save RNG state in non-distributed training
rng_states = { rng_states = {
"python": random.getstate(), "python": random.getstate(),
...@@ -1552,6 +1548,10 @@ class Trainer: ...@@ -1552,6 +1548,10 @@ class Trainer:
else: else:
torch.save(rng_states, os.path.join(output_dir, f"rng_state_{local_rank}.pth")) torch.save(rng_states, os.path.join(output_dir, f"rng_state_{local_rank}.pth"))
# Maybe delete some older checkpoints.
if self.is_world_process_zero():
self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
def _load_optimizer_and_scheduler(self, checkpoint): def _load_optimizer_and_scheduler(self, checkpoint):
"""If optimizer and scheduler states exist, load them.""" """If optimizer and scheduler states exist, load them."""
if checkpoint is None: if checkpoint is None:
...@@ -1924,7 +1924,7 @@ class Trainer: ...@@ -1924,7 +1924,7 @@ class Trainer:
ordering_and_checkpoint_path.append((os.path.getmtime(path), path)) ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
else: else:
regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path) regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
if regex_match and regex_match.groups(): if regex_match is not None and regex_match.groups() is not None:
ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path)) ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
checkpoints_sorted = sorted(ordering_and_checkpoint_path) checkpoints_sorted = sorted(ordering_and_checkpoint_path)
...@@ -1932,10 +1932,8 @@ class Trainer: ...@@ -1932,10 +1932,8 @@ class Trainer:
# Make sure we don't delete the best model. # Make sure we don't delete the best model.
if self.state.best_model_checkpoint is not None: if self.state.best_model_checkpoint is not None:
best_model_index = checkpoints_sorted.index(str(Path(self.state.best_model_checkpoint))) best_model_index = checkpoints_sorted.index(str(Path(self.state.best_model_checkpoint)))
checkpoints_sorted[best_model_index], checkpoints_sorted[-1] = ( for i in range(best_model_index, len(checkpoints_sorted) - 2):
checkpoints_sorted[-1], checkpoints_sorted[i], checkpoints_sorted[i + 1] = checkpoints_sorted[i + 1], checkpoints_sorted[i]
checkpoints_sorted[best_model_index],
)
return checkpoints_sorted return checkpoints_sorted
def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None: def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:
...@@ -1947,7 +1945,17 @@ class Trainer: ...@@ -1947,7 +1945,17 @@ class Trainer:
if len(checkpoints_sorted) <= self.args.save_total_limit: if len(checkpoints_sorted) <= self.args.save_total_limit:
return return
number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - self.args.save_total_limit) # If save_total_limit=1 with load_best_mode_at_end=True, we could end up deleting the last checkpoint, which
# we don't do to allow resuming.
save_total_limit = self.args.save_total_limit
if (
self.state.best_model_checkpoint is not None
and self.args.save_total_limit == 1
and checkpoints_sorted[-1] != self.state.best_model_checkpoint
):
save_total_limit = 2
number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete] checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
for checkpoint in checkpoints_to_be_deleted: for checkpoint in checkpoints_to_be_deleted:
logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit") logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
......
...@@ -21,6 +21,7 @@ import random ...@@ -21,6 +21,7 @@ import random
import re import re
import tempfile import tempfile
import unittest import unittest
from pathlib import Path
import numpy as np import numpy as np
...@@ -45,6 +46,7 @@ from transformers.testing_utils import ( ...@@ -45,6 +46,7 @@ from transformers.testing_utils import (
require_torch_multi_gpu, require_torch_multi_gpu,
slow, slow,
) )
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.utils.hp_naming import TrialShortNamer from transformers.utils.hp_naming import TrialShortNamer
...@@ -1048,6 +1050,35 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): ...@@ -1048,6 +1050,35 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
trainer.train() trainer.train()
self.assertTrue(isinstance(trainer.state.total_flos, float)) self.assertTrue(isinstance(trainer.state.total_flos, float))
def check_checkpoint_deletion(self, trainer, output_dir, expected):
# Make fake checkpoints
for n in [5, 10, 15, 20, 25]:
os.makedirs(os.path.join(output_dir, f"{PREFIX_CHECKPOINT_DIR}-{n}"), exist_ok=True)
trainer._rotate_checkpoints(output_dir=output_dir)
glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{PREFIX_CHECKPOINT_DIR}-*")]
values = [int(re.match(f".*{PREFIX_CHECKPOINT_DIR}-([0-9]+)", d).groups()[0]) for d in glob_checkpoints]
self.assertSetEqual(set(values), set(expected))
def test_checkpoint_rotation(self):
with tempfile.TemporaryDirectory() as tmp_dir:
# Without best model at end
trainer = get_regression_trainer(output_dir=tmp_dir, save_total_limit=2)
self.check_checkpoint_deletion(trainer, tmp_dir, [20, 25])
# With best model at end
trainer = get_regression_trainer(output_dir=tmp_dir, load_best_model_at_end=True, save_total_limit=2)
trainer.state.best_model_checkpoint = os.path.join(tmp_dir, "checkpoint-5")
self.check_checkpoint_deletion(trainer, tmp_dir, [5, 25])
# Edge case: we don't always honor save_total_limit=1 if load_best_model_at_end=True to be able to resume
# from checkpoint
trainer = get_regression_trainer(output_dir=tmp_dir, load_best_model_at_end=True, save_total_limit=1)
trainer.state.best_model_checkpoint = os.path.join(tmp_dir, "checkpoint-25")
self.check_checkpoint_deletion(trainer, tmp_dir, [25])
trainer.state.best_model_checkpoint = os.path.join(tmp_dir, "checkpoint-5")
self.check_checkpoint_deletion(trainer, tmp_dir, [5, 25])
def check_mem_metrics(self, trainer, check_func): def check_mem_metrics(self, trainer, check_func):
metrics = trainer.train().metrics metrics = trainer.train().metrics
check_func("init_mem_cpu_alloc_delta", metrics) check_func("init_mem_cpu_alloc_delta", metrics)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment