Unverified Commit 0d5332a4 authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

use registry to manage hooks (#199)

parent c2c9fced
# Copyright (c) Open-MMLab. All rights reserved. # Copyright (c) Open-MMLab. All rights reserved.
from .checkpoint import CheckpointHook from .checkpoint import CheckpointHook
from .closure import ClosureHook from .closure import ClosureHook
from .hook import Hook from .hook import HOOKS, Hook
from .iter_timer import IterTimerHook from .iter_timer import IterTimerHook
from .logger import (LoggerHook, TensorboardLoggerHook, TextLoggerHook, from .logger import (LoggerHook, TensorboardLoggerHook, TextLoggerHook,
WandbLoggerHook) WandbLoggerHook)
...@@ -11,7 +11,7 @@ from .optimizer import OptimizerHook ...@@ -11,7 +11,7 @@ from .optimizer import OptimizerHook
from .sampler_seed import DistSamplerSeedHook from .sampler_seed import DistSamplerSeedHook
__all__ = [ __all__ = [
'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook', 'OptimizerHook', 'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook',
'IterTimerHook', 'DistSamplerSeedHook', 'EmptyCacheHook', 'LoggerHook', 'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook', 'EmptyCacheHook',
'TextLoggerHook', 'TensorboardLoggerHook', 'WandbLoggerHook' 'LoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook', 'WandbLoggerHook'
] ]
# Copyright (c) Open-MMLab. All rights reserved. # Copyright (c) Open-MMLab. All rights reserved.
from ..dist_utils import master_only from ..dist_utils import master_only
from .hook import Hook from .hook import HOOKS, Hook
@HOOKS.register_module
class CheckpointHook(Hook): class CheckpointHook(Hook):
def __init__(self, def __init__(self,
......
# Copyright (c) Open-MMLab. All rights reserved. # Copyright (c) Open-MMLab. All rights reserved.
from .hook import Hook from .hook import HOOKS, Hook
@HOOKS.register_module
class ClosureHook(Hook): class ClosureHook(Hook):
def __init__(self, fn_name, fn): def __init__(self, fn_name, fn):
......
# Copyright (c) Open-MMLab. All rights reserved. # Copyright (c) Open-MMLab. All rights reserved.
from mmcv.utils import Registry
HOOKS = Registry('hook')
class Hook(object): class Hook(object):
def before_run(self, runner): def before_run(self, runner):
......
# Copyright (c) Open-MMLab. All rights reserved. # Copyright (c) Open-MMLab. All rights reserved.
import time import time
from .hook import Hook from .hook import HOOKS, Hook
@HOOKS.register_module
class IterTimerHook(Hook): class IterTimerHook(Hook):
def before_epoch(self, runner): def before_epoch(self, runner):
......
...@@ -3,10 +3,12 @@ import os.path as osp ...@@ -3,10 +3,12 @@ import os.path as osp
import torch import torch
from ...dist_utils import master_only from mmcv.runner import master_only
from ..hook import HOOKS
from .base import LoggerHook from .base import LoggerHook
@HOOKS.register_module
class TensorboardLoggerHook(LoggerHook): class TensorboardLoggerHook(LoggerHook):
def __init__(self, def __init__(self,
......
...@@ -7,9 +7,11 @@ import torch ...@@ -7,9 +7,11 @@ import torch
import torch.distributed as dist import torch.distributed as dist
import mmcv import mmcv
from ..hook import HOOKS
from .base import LoggerHook from .base import LoggerHook
@HOOKS.register_module
class TextLoggerHook(LoggerHook): class TextLoggerHook(LoggerHook):
def __init__(self, interval=10, ignore_last=True, reset_flag=False): def __init__(self, interval=10, ignore_last=True, reset_flag=False):
......
# Copyright (c) Open-MMLab. All rights reserved. # Copyright (c) Open-MMLab. All rights reserved.
import numbers import numbers
from ...dist_utils import master_only from mmcv.runner import master_only
from ..hook import HOOKS
from .base import LoggerHook from .base import LoggerHook
@HOOKS.register_module
class WandbLoggerHook(LoggerHook): class WandbLoggerHook(LoggerHook):
def __init__(self, def __init__(self,
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
from __future__ import division from __future__ import division
from math import cos, pi from math import cos, pi
from .hook import Hook from .hook import HOOKS, Hook
class LrUpdaterHook(Hook): class LrUpdaterHook(Hook):
...@@ -88,6 +88,7 @@ class LrUpdaterHook(Hook): ...@@ -88,6 +88,7 @@ class LrUpdaterHook(Hook):
self._set_lr(runner, warmup_lr) self._set_lr(runner, warmup_lr)
@HOOKS.register_module
class FixedLrUpdaterHook(LrUpdaterHook): class FixedLrUpdaterHook(LrUpdaterHook):
def __init__(self, **kwargs): def __init__(self, **kwargs):
...@@ -97,6 +98,7 @@ class FixedLrUpdaterHook(LrUpdaterHook): ...@@ -97,6 +98,7 @@ class FixedLrUpdaterHook(LrUpdaterHook):
return base_lr return base_lr
@HOOKS.register_module
class StepLrUpdaterHook(LrUpdaterHook): class StepLrUpdaterHook(LrUpdaterHook):
def __init__(self, step, gamma=0.1, **kwargs): def __init__(self, step, gamma=0.1, **kwargs):
...@@ -126,6 +128,7 @@ class StepLrUpdaterHook(LrUpdaterHook): ...@@ -126,6 +128,7 @@ class StepLrUpdaterHook(LrUpdaterHook):
return base_lr * self.gamma**exp return base_lr * self.gamma**exp
@HOOKS.register_module
class ExpLrUpdaterHook(LrUpdaterHook): class ExpLrUpdaterHook(LrUpdaterHook):
def __init__(self, gamma, **kwargs): def __init__(self, gamma, **kwargs):
...@@ -137,6 +140,7 @@ class ExpLrUpdaterHook(LrUpdaterHook): ...@@ -137,6 +140,7 @@ class ExpLrUpdaterHook(LrUpdaterHook):
return base_lr * self.gamma**progress return base_lr * self.gamma**progress
@HOOKS.register_module
class PolyLrUpdaterHook(LrUpdaterHook): class PolyLrUpdaterHook(LrUpdaterHook):
def __init__(self, power=1., min_lr=0., **kwargs): def __init__(self, power=1., min_lr=0., **kwargs):
...@@ -155,6 +159,7 @@ class PolyLrUpdaterHook(LrUpdaterHook): ...@@ -155,6 +159,7 @@ class PolyLrUpdaterHook(LrUpdaterHook):
return (base_lr - self.min_lr) * coeff + self.min_lr return (base_lr - self.min_lr) * coeff + self.min_lr
@HOOKS.register_module
class InvLrUpdaterHook(LrUpdaterHook): class InvLrUpdaterHook(LrUpdaterHook):
def __init__(self, gamma, power=1., **kwargs): def __init__(self, gamma, power=1., **kwargs):
...@@ -167,6 +172,7 @@ class InvLrUpdaterHook(LrUpdaterHook): ...@@ -167,6 +172,7 @@ class InvLrUpdaterHook(LrUpdaterHook):
return base_lr * (1 + self.gamma * progress)**(-self.power) return base_lr * (1 + self.gamma * progress)**(-self.power)
@HOOKS.register_module
class CosineLrUpdaterHook(LrUpdaterHook): class CosineLrUpdaterHook(LrUpdaterHook):
def __init__(self, target_lr=0, **kwargs): def __init__(self, target_lr=0, **kwargs):
......
# Copyright (c) Open-MMLab. All rights reserved. # Copyright (c) Open-MMLab. All rights reserved.
import torch import torch
from .hook import Hook from .hook import HOOKS, Hook
@HOOKS.register_module
class EmptyCacheHook(Hook): class EmptyCacheHook(Hook):
def __init__(self, before_epoch=False, after_epoch=True, after_iter=False): def __init__(self, before_epoch=False, after_epoch=True, after_iter=False):
......
# Copyright (c) Open-MMLab. All rights reserved. # Copyright (c) Open-MMLab. All rights reserved.
from torch.nn.utils import clip_grad from torch.nn.utils import clip_grad
from .hook import Hook from .hook import HOOKS, Hook
@HOOKS.register_module
class OptimizerHook(Hook): class OptimizerHook(Hook):
def __init__(self, grad_clip=None): def __init__(self, grad_clip=None):
......
# Copyright (c) Open-MMLab. All rights reserved. # Copyright (c) Open-MMLab. All rights reserved.
from .hook import Hook from .hook import HOOKS, Hook
@HOOKS.register_module
class DistSamplerSeedHook(Hook): class DistSamplerSeedHook(Hook):
def before_epoch(self, runner): def before_epoch(self, runner):
......
...@@ -6,11 +6,9 @@ import time ...@@ -6,11 +6,9 @@ import time
import torch import torch
import mmcv import mmcv
from . import hooks
from .checkpoint import load_checkpoint, save_checkpoint from .checkpoint import load_checkpoint, save_checkpoint
from .dist_utils import get_dist_info from .dist_utils import get_dist_info
from .hooks import (CheckpointHook, Hook, IterTimerHook, LrUpdaterHook, from .hooks import HOOKS, Hook, IterTimerHook
OptimizerHook, lr_updater)
from .log_buffer import LogBuffer from .log_buffer import LogBuffer
from .priority import get_priority from .priority import get_priority
from .utils import get_host_info, get_time_str, obj_from_dict from .utils import get_host_info, get_time_str, obj_from_dict
...@@ -223,16 +221,6 @@ class Runner(object): ...@@ -223,16 +221,6 @@ class Runner(object):
if not inserted: if not inserted:
self._hooks.insert(0, hook) self._hooks.insert(0, hook)
def build_hook(self, args, hook_type=None):
if isinstance(args, Hook):
return args
elif isinstance(args, dict):
assert issubclass(hook_type, Hook)
return hook_type(**args)
else:
raise TypeError('"args" must be either a Hook object'
' or dict, not {}'.format(type(args)))
def call_hook(self, fn_name): def call_hook(self, fn_name):
for hook in self._hooks: for hook in self._hooks:
getattr(hook, fn_name)(self) getattr(hook, fn_name)(self)
...@@ -373,26 +361,41 @@ class Runner(object): ...@@ -373,26 +361,41 @@ class Runner(object):
time.sleep(1) # wait for some hooks like loggers to finish time.sleep(1) # wait for some hooks like loggers to finish
self.call_hook('after_run') self.call_hook('after_run')
def register_lr_hooks(self, lr_config): def register_lr_hook(self, lr_config):
if isinstance(lr_config, LrUpdaterHook): if isinstance(lr_config, dict):
self.register_hook(lr_config)
elif isinstance(lr_config, dict):
assert 'policy' in lr_config assert 'policy' in lr_config
# from .hooks import lr_updater hook_type = lr_config.pop('policy').title() + 'LrUpdaterHook'
hook_name = lr_config['policy'].title() + 'LrUpdaterHook' lr_config['type'] = hook_type
if not hasattr(lr_updater, hook_name): hook = mmcv.build_from_cfg(lr_config, HOOKS)
raise ValueError('"{}" does not exist'.format(hook_name)) else:
hook_cls = getattr(lr_updater, hook_name) hook = lr_config
self.register_hook(hook_cls(**lr_config)) self.register_hook(hook)
def register_optimizer_hook(self, optimizer_config):
if optimizer_config is None:
return
if isinstance(optimizer_config, dict):
optimizer_config.setdefault('type', 'OptimizerHook')
hook = mmcv.build_from_cfg(optimizer_config, HOOKS)
else: else:
raise TypeError('"lr_config" must be either a LrUpdaterHook object' hook = optimizer_config
' or dict, not {}'.format(type(lr_config))) self.register_hook(hook)
def register_checkpoint_hook(self, checkpoint_config):
if checkpoint_config is None:
return
if isinstance(checkpoint_config, dict):
checkpoint_config.setdefault('type', 'CheckpointHook')
hook = mmcv.build_from_cfg(checkpoint_config, HOOKS)
else:
hook = checkpoint_config
self.register_hook(hook)
def register_logger_hooks(self, log_config): def register_logger_hooks(self, log_config):
log_interval = log_config['interval'] log_interval = log_config['interval']
for info in log_config['hooks']: for info in log_config['hooks']:
logger_hook = obj_from_dict( logger_hook = mmcv.build_from_cfg(
info, hooks, default_args=dict(interval=log_interval)) info, HOOKS, default_args=dict(interval=log_interval))
self.register_hook(logger_hook, priority='VERY_LOW') self.register_hook(logger_hook, priority='VERY_LOW')
def register_training_hooks(self, def register_training_hooks(self,
...@@ -410,13 +413,8 @@ class Runner(object): ...@@ -410,13 +413,8 @@ class Runner(object):
- IterTimerHook - IterTimerHook
- LoggerHook(s) - LoggerHook(s)
""" """
if optimizer_config is None: self.register_lr_hook(lr_config)
optimizer_config = {} self.register_optimizer_hook(optimizer_config)
if checkpoint_config is None: self.register_checkpoint_hook(checkpoint_config)
checkpoint_config = {}
self.register_lr_hooks(lr_config)
self.register_hook(self.build_hook(optimizer_config, OptimizerHook))
self.register_hook(self.build_hook(checkpoint_config, CheckpointHook))
self.register_hook(IterTimerHook()) self.register_hook(IterTimerHook())
if log_config is not None: self.register_logger_hooks(log_config)
self.register_logger_hooks(log_config)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment