Unverified Commit 8ac858b1 authored by Bin Zhang's avatar Bin Zhang Committed by GitHub
Browse files

add max_keep_ckpts parameter (#227)



* add max_keep_ckpts parameters to save memory when saving models

* format

* format

* format

* fixed linting error
Co-authored-by: default avatarz-bingo <z-bingo@outlook.com>
parent 732c3797
# Copyright (c) Open-MMLab. All rights reserved.
import os
from ..dist_utils import master_only
from .hook import HOOKS, Hook
......@@ -10,10 +12,12 @@ class CheckpointHook(Hook):
interval=-1,
save_optimizer=True,
out_dir=None,
max_keep_ckpts=-1,
**kwargs):
self.interval = interval
self.save_optimizer = save_optimizer
self.out_dir = out_dir
self.max_keep_ckpts = max_keep_ckpts
self.args = kwargs
@master_only
......@@ -25,3 +29,15 @@ class CheckpointHook(Hook):
self.out_dir = runner.work_dir
runner.save_checkpoint(
self.out_dir, save_optimizer=self.save_optimizer, **self.args)
# remove other checkpoints
if self.max_keep_ckpts > 0:
filename_tmpl = self.args.get('filename_tmpl', 'epoch_{}.pth')
current_epoch = runner.epoch + 1
for epoch in range(current_epoch - self.max_keep_ckpts, 0, -1):
ckpt_path = os.path.join(self.out_dir,
filename_tmpl.format(epoch))
if os.path.exists(ckpt_path):
os.remove(ckpt_path)
else:
break
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment