"references/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "7d1cd1de34298ccb4a993d7c6af59b08cf8ac094"
Unverified commit 27c81690, authored by Kai Chen, committed by GitHub

Merge pull request #7 from OceanPang/env

fix checkpoint & runner bugs
parents 923091b5 818c40c3
```diff
 from .hook import Hook
-from .checkpoint_saver import CheckpointSaverHook
+from .checkpoint_saver import CheckpointHook
 from .closure import ClosureHook
 from .lr_updater import LrUpdaterHook
-from .optimizer_stepper import OptimizerStepperHook
+from .optimizer_stepper import OptimizerHook
 from .iter_timer import IterTimerHook
 from .logger import *
```
```diff
@@ -2,7 +2,7 @@ from .hook import Hook
 from ..utils import master_only
 
 
-class CheckpointSaverHook(Hook):
+class CheckpointHook(Hook):
 
     def __init__(self,
                  interval=-1,
```
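The body of the renamed hook is collapsed in the diff above. Piecing together the visible parts (the `interval` argument, the `master_only` import, and `Runner.save_checkpoint` further down), here is a minimal sketch of how `CheckpointHook` presumably behaves; `save_optimizer`, `out_dir`, and `runner.work_dir` are assumptions, not confirmed by this commit:

```python
# Hedged sketch only -- the real body is collapsed in the diff above.
class CheckpointHook(Hook):  # Hook from .hook, as in the module above

    def __init__(self, interval=-1, save_optimizer=True, out_dir=None):
        self.interval = interval              # interval < 0 disables periodic saves
        self.save_optimizer = save_optimizer  # assumed kwarg, not shown above
        self.out_dir = out_dir                # assumed kwarg, not shown above

    @master_only  # decorator imported above: save only on the rank-0 process
    def after_train_epoch(self, runner):
        if self.interval < 0 or (runner.epoch + 1) % self.interval != 0:
            return
        out_dir = self.out_dir if self.out_dir else runner.work_dir
        runner.save_checkpoint(out_dir, save_optimizer=self.save_optimizer)
```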
```diff
@@ -3,7 +3,7 @@ from torch.nn.utils import clip_grad
 from .hook import Hook
 
 
-class OptimizerStepperHook(Hook):
+class OptimizerHook(Hook):
 
     def __init__(self, grad_clip=False, max_norm=35, norm_type=2):
         self.grad_clip = grad_clip
```
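Likewise, the body of `OptimizerHook` is collapsed. Given the `clip_grad` import and the `__init__` signature, it presumably steps the optimizer after every training iteration roughly as sketched below; `after_train_iter` and `runner.outputs['loss']` are assumptions based on the runner's hook calls, not shown in this diff:

```python
# Hedged sketch of the collapsed hook body, under the assumptions above.
from torch.nn.utils import clip_grad

class OptimizerHook(Hook):  # Hook from .hook, as in the module above

    def __init__(self, grad_clip=False, max_norm=35, norm_type=2):
        self.grad_clip = grad_clip
        self.max_norm = max_norm
        self.norm_type = norm_type

    def after_train_iter(self, runner):
        runner.optimizer.zero_grad()
        runner.outputs['loss'].backward()
        if self.grad_clip:
            # clip the total gradient norm over all parameters before stepping
            clip_grad.clip_grad_norm_(runner.model.parameters(),
                                      max_norm=self.max_norm,
                                      norm_type=self.norm_type)
        runner.optimizer.step()
```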
```diff
@@ -8,8 +8,8 @@ from torch.nn.parallel import DataParallel, DistributedDataParallel
 from .log_buffer import LogBuffer
 from .. import hooks
-from ..hooks import (Hook, LrUpdaterHook, CheckpointSaverHook, IterTimerHook,
-                     OptimizerStepperHook)
+from ..hooks import (Hook, LrUpdaterHook, CheckpointHook, IterTimerHook,
+                     OptimizerHook)
 from ..io import load_checkpoint, save_checkpoint
 from ..utils import (get_dist_info, get_host_info, get_time_str,
                      add_file_handler, obj_from_dict)
```
```diff
@@ -182,6 +182,16 @@ class Runner(object):
         if not inserted:
             self._hooks.insert(0, hook)
 
+    def build_hook(self, args, hook_type=None):
+        if isinstance(args, Hook):
+            return args
+        elif isinstance(args, dict):
+            assert issubclass(hook_type, Hook)
+            return hook_type(**args)
+        else:
+            raise TypeError('"args" must be either a Hook object'
+                            ' or dict, not {}'.format(type(args)))
+
     def call_hook(self, fn_name):
         for hook in self._hooks:
             getattr(hook, fn_name)(self)
```
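The new `build_hook` helper is the heart of this change: a hook argument may now be either an already-constructed `Hook` or a plain dict of constructor kwargs. A self-contained sketch of the contract, using minimal stand-in classes rather than the real ones:

```python
# Stand-in classes to demonstrate build_hook in isolation.
class Hook(object):
    pass

class OptimizerHook(Hook):
    def __init__(self, grad_clip=False, max_norm=35, norm_type=2):
        self.grad_clip = grad_clip
        self.max_norm = max_norm
        self.norm_type = norm_type

def build_hook(args, hook_type=None):
    if isinstance(args, Hook):
        return args               # pass pre-built hooks through unchanged
    elif isinstance(args, dict):
        assert issubclass(hook_type, Hook)
        return hook_type(**args)  # treat a dict as constructor kwargs
    else:
        raise TypeError('"args" must be either a Hook object'
                        ' or dict, not {}'.format(type(args)))

# Both call styles yield an equivalent OptimizerHook:
hook_a = build_hook(dict(grad_clip=True, max_norm=20), OptimizerHook)
hook_b = build_hook(OptimizerHook(grad_clip=True, max_norm=20))
assert hook_a.max_norm == hook_b.max_norm == 20
```

This keeps dict-style configs working while letting users register custom hook instances directly.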
```diff
@@ -201,7 +211,7 @@ class Runner(object):
         else:
             meta.update(epoch=self.epoch + 1, iter=self.iter)
 
-        filename = osp.join(out_dir, filename_tmpl.format(self.epoch))
+        filename = osp.join(out_dir, filename_tmpl.format(self.epoch + 1))
         linkname = osp.join(out_dir, 'latest.pth')
         optimizer = self.optimizer if save_optimizer else None
         save_checkpoint(self.model, filename, optimizer=optimizer, meta=meta)
```
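This is the checkpoint bug from the PR title: `self.epoch` is the zero-based index of the epoch that just finished, and `meta` records `epoch=self.epoch + 1`, so the filename was off by one relative to the metadata. A quick illustration, assuming the usual `'epoch_{}.pth'` template (the default is not visible in this diff):

```python
filename_tmpl = 'epoch_{}.pth'  # assumed default template
epoch = 0                       # zero-based index: first epoch just completed
print(filename_tmpl.format(epoch))      # before: 'epoch_0.pth' -- disagreed with meta
print(filename_tmpl.format(epoch + 1))  # after:  'epoch_1.pth' -- matches meta['epoch']
```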
```diff
@@ -213,7 +223,6 @@ class Runner(object):
         self.data_loader = data_loader
         self._max_iters = self._max_epochs * len(data_loader)
         self.call_hook('before_train_epoch')
-
         for i, data_batch in enumerate(data_loader):
             self._inner_iter = i
             self.call_hook('before_train_iter')
```
```diff
@@ -330,7 +339,7 @@ class Runner(object):
 
     def register_training_hooks(self,
                                 lr_config,
-                                grad_clip_config=None,
+                                optimizer_config=None,
                                 checkpoint_config=None,
                                 log_config=None):
         """Register default hooks for training.
@@ -342,13 +351,13 @@ class Runner(object):
         - IterTimerHook
         - LoggerHook
         """
-        if grad_clip_config is None:
-            grad_clip_config = {}
+        if optimizer_config is None:
+            optimizer_config = {}
         if checkpoint_config is None:
             checkpoint_config = {}
         self.register_lr_hooks(lr_config)
-        self.register_hook(OptimizerStepperHook(**grad_clip_config))
-        self.register_hook(CheckpointSaverHook(**checkpoint_config))
+        self.register_hook(self.build_hook(optimizer_config, OptimizerHook))
+        self.register_hook(self.build_hook(checkpoint_config, CheckpointHook))
         self.register_hook(IterTimerHook())
         if log_config is not None:
             self.register_logger_hooks(log_config)
```
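Putting it together, `register_training_hooks` now routes both configs through `build_hook`. A usage sketch, assuming a constructed `Runner` instance `runner` and that `register_lr_hooks` accepts a policy dict (neither is shown in this diff); the optimizer and checkpoint kwargs match the `__init__` signatures above:

```python
# Dict-style configs, forwarded as constructor kwargs via build_hook:
runner.register_training_hooks(
    lr_config=dict(policy='step', step=[8, 11]),  # assumed LR schedule format
    optimizer_config=dict(grad_clip=True, max_norm=35, norm_type=2),
    checkpoint_config=dict(interval=1),
)

# Equivalently, pass pre-built hooks; build_hook returns them unchanged:
runner.register_training_hooks(
    lr_config=dict(policy='step', step=[8, 11]),
    optimizer_config=OptimizerHook(grad_clip=True),
    checkpoint_config=CheckpointHook(interval=1),
)
```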