Unverified Commit 0d5332a4 authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

use registry to manage hooks (#199)

parent c2c9fced
# Copyright (c) Open-MMLab. All rights reserved.
from .checkpoint import CheckpointHook
from .closure import ClosureHook
from .hook import Hook
from .hook import HOOKS, Hook
from .iter_timer import IterTimerHook
from .logger import (LoggerHook, TensorboardLoggerHook, TextLoggerHook,
WandbLoggerHook)
......@@ -11,7 +11,7 @@ from .optimizer import OptimizerHook
from .sampler_seed import DistSamplerSeedHook
__all__ = [
'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook', 'OptimizerHook',
'IterTimerHook', 'DistSamplerSeedHook', 'EmptyCacheHook', 'LoggerHook',
'TextLoggerHook', 'TensorboardLoggerHook', 'WandbLoggerHook'
'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook',
'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook', 'EmptyCacheHook',
'LoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook', 'WandbLoggerHook'
]
# Copyright (c) Open-MMLab. All rights reserved.
from ..dist_utils import master_only
from .hook import Hook
from .hook import HOOKS, Hook
@HOOKS.register_module
class CheckpointHook(Hook):
def __init__(self,
......
# Copyright (c) Open-MMLab. All rights reserved.
from .hook import Hook
from .hook import HOOKS, Hook
@HOOKS.register_module
class ClosureHook(Hook):
def __init__(self, fn_name, fn):
......
# Copyright (c) Open-MMLab. All rights reserved.
from mmcv.utils import Registry
HOOKS = Registry('hook')
class Hook(object):
def before_run(self, runner):
......
# Copyright (c) Open-MMLab. All rights reserved.
import time
from .hook import Hook
from .hook import HOOKS, Hook
@HOOKS.register_module
class IterTimerHook(Hook):
def before_epoch(self, runner):
......
......@@ -3,10 +3,12 @@ import os.path as osp
import torch
from ...dist_utils import master_only
from mmcv.runner import master_only
from ..hook import HOOKS
from .base import LoggerHook
@HOOKS.register_module
class TensorboardLoggerHook(LoggerHook):
def __init__(self,
......
......@@ -7,9 +7,11 @@ import torch
import torch.distributed as dist
import mmcv
from ..hook import HOOKS
from .base import LoggerHook
@HOOKS.register_module
class TextLoggerHook(LoggerHook):
def __init__(self, interval=10, ignore_last=True, reset_flag=False):
......
# Copyright (c) Open-MMLab. All rights reserved.
import numbers
from ...dist_utils import master_only
from mmcv.runner import master_only
from ..hook import HOOKS
from .base import LoggerHook
@HOOKS.register_module
class WandbLoggerHook(LoggerHook):
def __init__(self,
......
......@@ -2,7 +2,7 @@
from __future__ import division
from math import cos, pi
from .hook import Hook
from .hook import HOOKS, Hook
class LrUpdaterHook(Hook):
......@@ -88,6 +88,7 @@ class LrUpdaterHook(Hook):
self._set_lr(runner, warmup_lr)
@HOOKS.register_module
class FixedLrUpdaterHook(LrUpdaterHook):
def __init__(self, **kwargs):
......@@ -97,6 +98,7 @@ class FixedLrUpdaterHook(LrUpdaterHook):
return base_lr
@HOOKS.register_module
class StepLrUpdaterHook(LrUpdaterHook):
def __init__(self, step, gamma=0.1, **kwargs):
......@@ -126,6 +128,7 @@ class StepLrUpdaterHook(LrUpdaterHook):
return base_lr * self.gamma**exp
@HOOKS.register_module
class ExpLrUpdaterHook(LrUpdaterHook):
def __init__(self, gamma, **kwargs):
......@@ -137,6 +140,7 @@ class ExpLrUpdaterHook(LrUpdaterHook):
return base_lr * self.gamma**progress
@HOOKS.register_module
class PolyLrUpdaterHook(LrUpdaterHook):
def __init__(self, power=1., min_lr=0., **kwargs):
......@@ -155,6 +159,7 @@ class PolyLrUpdaterHook(LrUpdaterHook):
return (base_lr - self.min_lr) * coeff + self.min_lr
@HOOKS.register_module
class InvLrUpdaterHook(LrUpdaterHook):
def __init__(self, gamma, power=1., **kwargs):
......@@ -167,6 +172,7 @@ class InvLrUpdaterHook(LrUpdaterHook):
return base_lr * (1 + self.gamma * progress)**(-self.power)
@HOOKS.register_module
class CosineLrUpdaterHook(LrUpdaterHook):
def __init__(self, target_lr=0, **kwargs):
......
# Copyright (c) Open-MMLab. All rights reserved.
import torch
from .hook import Hook
from .hook import HOOKS, Hook
@HOOKS.register_module
class EmptyCacheHook(Hook):
def __init__(self, before_epoch=False, after_epoch=True, after_iter=False):
......
# Copyright (c) Open-MMLab. All rights reserved.
from torch.nn.utils import clip_grad
from .hook import Hook
from .hook import HOOKS, Hook
@HOOKS.register_module
class OptimizerHook(Hook):
def __init__(self, grad_clip=None):
......
# Copyright (c) Open-MMLab. All rights reserved.
from .hook import Hook
from .hook import HOOKS, Hook
@HOOKS.register_module
class DistSamplerSeedHook(Hook):
def before_epoch(self, runner):
......
......@@ -6,11 +6,9 @@ import time
import torch
import mmcv
from . import hooks
from .checkpoint import load_checkpoint, save_checkpoint
from .dist_utils import get_dist_info
from .hooks import (CheckpointHook, Hook, IterTimerHook, LrUpdaterHook,
OptimizerHook, lr_updater)
from .hooks import HOOKS, Hook, IterTimerHook
from .log_buffer import LogBuffer
from .priority import get_priority
from .utils import get_host_info, get_time_str, obj_from_dict
......@@ -223,16 +221,6 @@ class Runner(object):
if not inserted:
self._hooks.insert(0, hook)
def build_hook(self, args, hook_type=None):
if isinstance(args, Hook):
return args
elif isinstance(args, dict):
assert issubclass(hook_type, Hook)
return hook_type(**args)
else:
raise TypeError('"args" must be either a Hook object'
' or dict, not {}'.format(type(args)))
def call_hook(self, fn_name):
for hook in self._hooks:
getattr(hook, fn_name)(self)
......@@ -373,26 +361,41 @@ class Runner(object):
time.sleep(1) # wait for some hooks like loggers to finish
self.call_hook('after_run')
def register_lr_hooks(self, lr_config):
if isinstance(lr_config, LrUpdaterHook):
self.register_hook(lr_config)
elif isinstance(lr_config, dict):
def register_lr_hook(self, lr_config):
if isinstance(lr_config, dict):
assert 'policy' in lr_config
# from .hooks import lr_updater
hook_name = lr_config['policy'].title() + 'LrUpdaterHook'
if not hasattr(lr_updater, hook_name):
raise ValueError('"{}" does not exist'.format(hook_name))
hook_cls = getattr(lr_updater, hook_name)
self.register_hook(hook_cls(**lr_config))
hook_type = lr_config.pop('policy').title() + 'LrUpdaterHook'
lr_config['type'] = hook_type
hook = mmcv.build_from_cfg(lr_config, HOOKS)
else:
hook = lr_config
self.register_hook(hook)
def register_optimizer_hook(self, optimizer_config):
if optimizer_config is None:
return
if isinstance(optimizer_config, dict):
optimizer_config.setdefault('type', 'OptimizerHook')
hook = mmcv.build_from_cfg(optimizer_config, HOOKS)
else:
raise TypeError('"lr_config" must be either a LrUpdaterHook object'
' or dict, not {}'.format(type(lr_config)))
hook = optimizer_config
self.register_hook(hook)
def register_checkpoint_hook(self, checkpoint_config):
if checkpoint_config is None:
return
if isinstance(checkpoint_config, dict):
checkpoint_config.setdefault('type', 'CheckpointHook')
hook = mmcv.build_from_cfg(checkpoint_config, HOOKS)
else:
hook = checkpoint_config
self.register_hook(hook)
def register_logger_hooks(self, log_config):
log_interval = log_config['interval']
for info in log_config['hooks']:
logger_hook = obj_from_dict(
info, hooks, default_args=dict(interval=log_interval))
logger_hook = mmcv.build_from_cfg(
info, HOOKS, default_args=dict(interval=log_interval))
self.register_hook(logger_hook, priority='VERY_LOW')
def register_training_hooks(self,
......@@ -410,13 +413,8 @@ class Runner(object):
- IterTimerHook
- LoggerHook(s)
"""
if optimizer_config is None:
optimizer_config = {}
if checkpoint_config is None:
checkpoint_config = {}
self.register_lr_hooks(lr_config)
self.register_hook(self.build_hook(optimizer_config, OptimizerHook))
self.register_hook(self.build_hook(checkpoint_config, CheckpointHook))
self.register_lr_hook(lr_config)
self.register_optimizer_hook(optimizer_config)
self.register_checkpoint_hook(checkpoint_config)
self.register_hook(IterTimerHook())
if log_config is not None:
self.register_logger_hooks(log_config)
self.register_logger_hooks(log_config)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment