Unverified Commit 34b552b8 authored by Jerry Jiarui XU, committed by GitHub

[Feature] Support save last checkpoint (#853)

* [Feature] Support save last checkpoint

* move to before_run, update doc

* move to after train

* add comments
parent e5eaf2a7
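
For context, here is a minimal sketch of how the new option would typically be enabled from a training config. The `checkpoint_config` dict follows the usual downstream MMCV convention of passing keyword arguments through to `CheckpointHook`; the concrete values below are illustrative assumptions, not part of this commit.

```python
# Illustrative checkpoint config (assumed values). With interval=3 over 8
# training epochs, checkpoints land at epochs 3 and 6, and save_last=True
# (the new flag, defaulting to True) forces one more at the final epoch 8.
checkpoint_config = dict(
    interval=3,          # save every 3 epochs (or iterations when by_epoch=False)
    save_last=True,      # option introduced by this commit
    max_keep_ckpts=-1,   # -1 keeps all checkpoints
)
```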
@@ -24,6 +24,8 @@ class CheckpointHook(Hook):
             In some cases we want only the latest few checkpoints and would
             like to delete old ones to save the disk space.
             Default: -1, which means unlimited.
+        save_last (bool): Whether to force the last checkpoint to be saved
+            regardless of interval.
         sync_buffer (bool): Whether to synchronize buffers in different
             gpus. Default: False.
     """
@@ -34,6 +36,7 @@ class CheckpointHook(Hook):
                  save_optimizer=True,
                  out_dir=None,
                  max_keep_ckpts=-1,
+                 save_last=True,
                  sync_buffer=False,
                  **kwargs):
         self.interval = interval
@@ -41,23 +44,33 @@ class CheckpointHook(Hook):
         self.save_optimizer = save_optimizer
         self.out_dir = out_dir
         self.max_keep_ckpts = max_keep_ckpts
+        self.save_last = save_last
         self.args = kwargs
         self.sync_buffer = sync_buffer
 
+    def before_run(self, runner):
+        if not self.out_dir:
+            self.out_dir = runner.work_dir
+
     def after_train_epoch(self, runner):
-        if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
+        if not self.by_epoch:
             return
 
-        runner.logger.info(f'Saving checkpoint at {runner.epoch + 1} epochs')
-        if self.sync_buffer:
-            allreduce_params(runner.model.buffers())
-        self._save_checkpoint(runner)
+        # save checkpoint for following cases:
+        # 1. every ``self.interval`` epochs
+        # 2. reach the last epoch of training
+        if self.every_n_epochs(
+                runner, self.interval) or (self.save_last
+                                           and self.is_last_epoch(runner)):
+            runner.logger.info(
+                f'Saving checkpoint at {runner.epoch + 1} epochs')
+            if self.sync_buffer:
+                allreduce_params(runner.model.buffers())
+            self._save_checkpoint(runner)
 
     @master_only
     def _save_checkpoint(self, runner):
         """Save the current checkpoint and delete unwanted checkpoint."""
-        if not self.out_dir:
-            self.out_dir = runner.work_dir
         runner.save_checkpoint(
             self.out_dir, save_optimizer=self.save_optimizer, **self.args)
         if runner.meta is not None:
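
To make the new condition concrete, here is a small standalone sketch (not from the commit) that emulates the `after_train_epoch` logic above and prints which epochs receive a checkpoint when the interval does not divide the total number of epochs; `interval=3` and `max_epochs=8` are assumed example values.

```python
# Emulate the new save condition: every ``interval`` epochs, plus the last
# epoch when save_last is enabled.
interval, max_epochs, save_last = 3, 8, True   # assumed example values

saved = []
for epoch in range(max_epochs):                # runner.epoch counts from 0
    every_n = (epoch + 1) % interval == 0      # mirrors every_n_epochs(runner, interval)
    is_last = (epoch + 1) == max_epochs        # mirrors is_last_epoch(runner)
    if every_n or (save_last and is_last):
        saved.append(epoch + 1)

print(saved)  # [3, 6, 8] -> epoch 8 is kept only because save_last is True
```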
@@ -91,11 +104,17 @@ class CheckpointHook(Hook):
                     break
 
     def after_train_iter(self, runner):
-        if self.by_epoch or not self.every_n_iters(runner, self.interval):
+        if self.by_epoch:
             return
 
-        runner.logger.info(
-            f'Saving checkpoint at {runner.iter + 1} iterations')
-        if self.sync_buffer:
-            allreduce_params(runner.model.buffers())
-        self._save_checkpoint(runner)
+        # save checkpoint for following cases:
+        # 1. every ``self.interval`` iterations
+        # 2. reach the last iteration of training
+        if self.every_n_iters(
+                runner, self.interval) or (self.save_last
+                                           and self.is_last_iter(runner)):
+            runner.logger.info(
+                f'Saving checkpoint at {runner.iter + 1} iterations')
+            if self.sync_buffer:
+                allreduce_params(runner.model.buffers())
+            self._save_checkpoint(runner)
@@ -59,3 +59,9 @@ class Hook:
 
     def end_of_epoch(self, runner):
         return runner.inner_iter + 1 == len(runner.data_loader)
+
+    def is_last_epoch(self, runner):
+        return runner.epoch + 1 == runner._max_epochs
+
+    def is_last_iter(self, runner):
+        return runner.iter + 1 == runner._max_iters
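
The two helpers simply compare the runner's current progress against its totals. A rough usage sketch follows, using a made-up `FakeRunner` stand-in (a real `mmcv.runner` runner exposes `iter`/`_max_iters` and `epoch`/`_max_epochs` the same way), and assuming an mmcv version that already includes this commit.

```python
from mmcv.runner import Hook  # the base class patched in the hunk above


class FakeRunner:
    """Hypothetical stand-in object, only for illustrating is_last_iter()."""

    def __init__(self, cur_iter, max_iters):
        self.iter = cur_iter
        self._max_iters = max_iters


hook = Hook()
print(hook.is_last_iter(FakeRunner(cur_iter=9999, max_iters=10000)))  # True: 9999 + 1 == 10000
print(hook.is_last_iter(FakeRunner(cur_iter=5000, max_iters=10000)))  # False
```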