Unverified Commit 34b552b8 authored by Jerry Jiarui XU's avatar Jerry Jiarui XU Committed by GitHub
Browse files

[Feature] Support save last checkpoint (#853)

* [Feature] Support save last checkpoint

* move to before_run, update doc

* move to after train

* add comments
parent e5eaf2a7
...@@ -24,6 +24,8 @@ class CheckpointHook(Hook): ...@@ -24,6 +24,8 @@ class CheckpointHook(Hook):
In some cases we want only the latest few checkpoints and would In some cases we want only the latest few checkpoints and would
like to delete old ones to save the disk space. like to delete old ones to save the disk space.
Default: -1, which means unlimited. Default: -1, which means unlimited.
save_last (bool): Whether to force the last checkpoint to be saved
regardless of interval.
sync_buffer (bool): Whether to synchronize buffers in different sync_buffer (bool): Whether to synchronize buffers in different
gpus. Default: False. gpus. Default: False.
""" """
...@@ -34,6 +36,7 @@ class CheckpointHook(Hook): ...@@ -34,6 +36,7 @@ class CheckpointHook(Hook):
save_optimizer=True, save_optimizer=True,
out_dir=None, out_dir=None,
max_keep_ckpts=-1, max_keep_ckpts=-1,
save_last=True,
sync_buffer=False, sync_buffer=False,
**kwargs): **kwargs):
self.interval = interval self.interval = interval
...@@ -41,14 +44,26 @@ class CheckpointHook(Hook): ...@@ -41,14 +44,26 @@ class CheckpointHook(Hook):
self.save_optimizer = save_optimizer self.save_optimizer = save_optimizer
self.out_dir = out_dir self.out_dir = out_dir
self.max_keep_ckpts = max_keep_ckpts self.max_keep_ckpts = max_keep_ckpts
self.save_last = save_last
self.args = kwargs self.args = kwargs
self.sync_buffer = sync_buffer self.sync_buffer = sync_buffer
def before_run(self, runner):
if not self.out_dir:
self.out_dir = runner.work_dir
def after_train_epoch(self, runner): def after_train_epoch(self, runner):
if not self.by_epoch or not self.every_n_epochs(runner, self.interval): if not self.by_epoch:
return return
runner.logger.info(f'Saving checkpoint at {runner.epoch + 1} epochs') # save checkpoint for following cases:
# 1. every ``self.interval`` epochs
# 2. reach the last epoch of training
if self.every_n_epochs(
runner, self.interval) or (self.save_last
and self.is_last_epoch(runner)):
runner.logger.info(
f'Saving checkpoint at {runner.epoch + 1} epochs')
if self.sync_buffer: if self.sync_buffer:
allreduce_params(runner.model.buffers()) allreduce_params(runner.model.buffers())
self._save_checkpoint(runner) self._save_checkpoint(runner)
...@@ -56,8 +71,6 @@ class CheckpointHook(Hook): ...@@ -56,8 +71,6 @@ class CheckpointHook(Hook):
@master_only @master_only
def _save_checkpoint(self, runner): def _save_checkpoint(self, runner):
"""Save the current checkpoint and delete unwanted checkpoint.""" """Save the current checkpoint and delete unwanted checkpoint."""
if not self.out_dir:
self.out_dir = runner.work_dir
runner.save_checkpoint( runner.save_checkpoint(
self.out_dir, save_optimizer=self.save_optimizer, **self.args) self.out_dir, save_optimizer=self.save_optimizer, **self.args)
if runner.meta is not None: if runner.meta is not None:
...@@ -91,9 +104,15 @@ class CheckpointHook(Hook): ...@@ -91,9 +104,15 @@ class CheckpointHook(Hook):
break break
def after_train_iter(self, runner): def after_train_iter(self, runner):
if self.by_epoch or not self.every_n_iters(runner, self.interval): if self.by_epoch:
return return
# save checkpoint for following cases:
# 1. every ``self.interval`` iterations
# 2. reach the last iteration of training
if self.every_n_iters(
runner, self.interval) or (self.save_last
and self.is_last_iter(runner)):
runner.logger.info( runner.logger.info(
f'Saving checkpoint at {runner.iter + 1} iterations') f'Saving checkpoint at {runner.iter + 1} iterations')
if self.sync_buffer: if self.sync_buffer:
......
...@@ -59,3 +59,9 @@ class Hook: ...@@ -59,3 +59,9 @@ class Hook:
def end_of_epoch(self, runner): def end_of_epoch(self, runner):
return runner.inner_iter + 1 == len(runner.data_loader) return runner.inner_iter + 1 == len(runner.data_loader)
def is_last_epoch(self, runner):
return runner.epoch + 1 == runner._max_epochs
def is_last_iter(self, runner):
return runner.iter + 1 == runner._max_iters
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment