Unverified commit 27c81690, authored by Kai Chen and committed by GitHub
Browse files

Merge pull request #7 from OceanPang/env

fix checkpoint & runner bugs
parents 923091b5 818c40c3
from .hook import Hook
from .checkpoint_saver import CheckpointSaverHook
from .checkpoint_saver import CheckpointHook
from .closure import ClosureHook
from .lr_updater import LrUpdaterHook
from .optimizer_stepper import OptimizerStepperHook
from .optimizer_stepper import OptimizerHook
from .iter_timer import IterTimerHook
from .logger import *
......@@ -2,7 +2,7 @@ from .hook import Hook
from ..utils import master_only
class CheckpointSaverHook(Hook):
class CheckpointHook(Hook):
def __init__(self,
interval=-1,
......
......@@ -3,7 +3,7 @@ from torch.nn.utils import clip_grad
from .hook import Hook
class OptimizerStepperHook(Hook):
class OptimizerHook(Hook):
def __init__(self, grad_clip=False, max_norm=35, norm_type=2):
self.grad_clip = grad_clip
......
......@@ -8,8 +8,8 @@ from torch.nn.parallel import DataParallel, DistributedDataParallel
from .log_buffer import LogBuffer
from .. import hooks
from ..hooks import (Hook, LrUpdaterHook, CheckpointSaverHook, IterTimerHook,
OptimizerStepperHook)
from ..hooks import (Hook, LrUpdaterHook, CheckpointHook, IterTimerHook,
OptimizerHook)
from ..io import load_checkpoint, save_checkpoint
from ..utils import (get_dist_info, get_host_info, get_time_str,
add_file_handler, obj_from_dict)
......@@ -182,6 +182,16 @@ class Runner(object):
if not inserted:
self._hooks.insert(0, hook)
def build_hook(self, args, hook_type=None):
    """Return a hook instance built from ``args``.

    Args:
        args (Hook or dict): An existing hook object, which is returned
            unchanged, or a dict of keyword arguments that is expanded
            into a call to ``hook_type``.
        hook_type (type, optional): ``Hook`` subclass to instantiate when
            ``args`` is a dict. It is not consulted when ``args`` is
            already a hook object.

    Returns:
        Either ``args`` itself or the result of ``hook_type(**args)``.

    Raises:
        TypeError: If ``args`` is neither a ``Hook`` nor a dict, or if
            ``args`` is a dict and ``hook_type`` is not a ``Hook``
            subclass (including the default ``None``).
    """
    if isinstance(args, Hook):
        return args
    if not isinstance(args, dict):
        raise TypeError('"args" must be either a Hook object'
                        ' or dict, not {}'.format(type(args)))
    # Validate explicitly instead of using ``assert``: assertions are
    # removed under ``python -O``, and ``issubclass`` itself raises an
    # unhelpful error when handed a non-class such as the default ``None``.
    if not (isinstance(hook_type, type) and issubclass(hook_type, Hook)):
        raise TypeError('"hook_type" must be a subclass of Hook when "args"'
                        ' is a dict, not {!r}'.format(hook_type))
    return hook_type(**args)
def call_hook(self, fn_name):
    """Invoke the method called ``fn_name`` on every hook in ``self._hooks``.

    Hooks are visited in the order they are stored, and each method is
    called with this object as its only argument.

    Args:
        fn_name (str): Name of the hook method to look up and call.
    """
    for registered in self._hooks:
        callback = getattr(registered, fn_name)
        callback(self)
......@@ -201,7 +211,7 @@ class Runner(object):
else:
meta.update(epoch=self.epoch + 1, iter=self.iter)
filename = osp.join(out_dir, filename_tmpl.format(self.epoch))
filename = osp.join(out_dir, filename_tmpl.format(self.epoch + 1))
linkname = osp.join(out_dir, 'latest.pth')
optimizer = self.optimizer if save_optimizer else None
save_checkpoint(self.model, filename, optimizer=optimizer, meta=meta)
......@@ -213,7 +223,6 @@ class Runner(object):
self.data_loader = data_loader
self._max_iters = self._max_epochs * len(data_loader)
self.call_hook('before_train_epoch')
for i, data_batch in enumerate(data_loader):
self._inner_iter = i
self.call_hook('before_train_iter')
......@@ -330,7 +339,7 @@ class Runner(object):
def register_training_hooks(self,
lr_config,
grad_clip_config=None,
optimizer_config=None,
checkpoint_config=None,
log_config=None):
"""Register default hooks for training.
......@@ -342,13 +351,13 @@ class Runner(object):
- IterTimerHook
- LoggerHook
"""
if grad_clip_config is None:
grad_clip_config = {}
if optimizer_config is None:
optimizer_config = {}
if checkpoint_config is None:
checkpoint_config = {}
self.register_lr_hooks(lr_config)
self.register_hook(OptimizerStepperHook(**grad_clip_config))
self.register_hook(CheckpointSaverHook(**checkpoint_config))
self.register_hook(self.build_hook(optimizer_config, OptimizerHook))
self.register_hook(self.build_hook(checkpoint_config, CheckpointHook))
self.register_hook(IterTimerHook())
if log_config is not None:
self.register_logger_hooks(log_config)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment