Unverified Commit 15bcaa9c authored by Ma Zerun, committed by GitHub

Add custom hook by config file (#970)

* Assign different priorities to the default hooks, and add custom hook registration in the base runner.

* Add custom hook registration in the example train file

* Add a unit test for custom hooks

* Code formatting
parent 9b8dd083
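
For orientation before the diffs: with this change, a training config can declare extra hooks that the example train script forwards to the runner. A minimal sketch of what such a config entry could look like (the hook name `MyToyHook` and its arguments are hypothetical and not part of this commit):

```python
# Hypothetical snippet from a training config file (sketch only).
# Each dict is built through the HOOKS registry and registered with the
# priority carried in the dict; 'MyToyHook' is an illustrative, made-up hook.
custom_train_hooks = [
    dict(type='MyToyHook', priority=49),  # ordered before the optimizer hook (priority 50)
    dict(type='MyToyHook', priority=95),  # ordered after the logger hooks (priority 90)
]
```
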
@@ -159,7 +159,8 @@ def main():
lr_config=cfg.lr_config,
optimizer_config=cfg.optimizer_config,
checkpoint_config=cfg.checkpoint_config,
log_config=cfg.log_config)
log_config=cfg.log_config,
custom_hooks_config=cfg.get('custom_train_hooks', None))
if dist:
runner.register_hook(DistSamplerSeedHook())
@@ -391,7 +391,7 @@ class BaseRunner(metaclass=ABCMeta):
hook = mmcv.build_from_cfg(lr_config, HOOKS)
else:
hook = lr_config
self.register_hook(hook)
self.register_hook(hook, priority=10)
def register_momentum_hook(self, momentum_config):
if momentum_config is None:
@@ -412,7 +412,7 @@ class BaseRunner(metaclass=ABCMeta):
hook = mmcv.build_from_cfg(momentum_config, HOOKS)
else:
hook = momentum_config
self.register_hook(hook)
self.register_hook(hook, priority=30)
def register_optimizer_hook(self, optimizer_config):
if optimizer_config is None:
@@ -422,7 +422,7 @@ class BaseRunner(metaclass=ABCMeta):
hook = mmcv.build_from_cfg(optimizer_config, HOOKS)
else:
hook = optimizer_config
self.register_hook(hook)
self.register_hook(hook, priority=50)
def register_checkpoint_hook(self, checkpoint_config):
if checkpoint_config is None:
@@ -432,7 +432,7 @@ class BaseRunner(metaclass=ABCMeta):
hook = mmcv.build_from_cfg(checkpoint_config, HOOKS)
else:
hook = checkpoint_config
self.register_hook(hook)
self.register_hook(hook, priority=70)
def register_logger_hooks(self, log_config):
if log_config is None:
@@ -441,7 +441,7 @@ class BaseRunner(metaclass=ABCMeta):
for info in log_config['hooks']:
logger_hook = mmcv.build_from_cfg(
info, HOOKS, default_args=dict(interval=log_interval))
self.register_hook(logger_hook, priority='VERY_LOW')
self.register_hook(logger_hook, priority=90)
def register_timer_hook(self, timer_config):
if timer_config is None:
@@ -451,7 +451,20 @@ class BaseRunner(metaclass=ABCMeta):
hook = mmcv.build_from_cfg(timer_config_, HOOKS)
else:
hook = timer_config
self.register_hook(hook)
self.register_hook(hook, priority=80)
def register_custom_hooks(self, custom_config):
if custom_config is None:
return
if not isinstance(custom_config, list):
custom_config = [custom_config]
for item in custom_config:
if isinstance(item, dict):
self.register_hook_from_cfg(item)
else:
self.register_hook(item, priority='NORMAL')
def register_profiler_hook(self, profiler_config):
if profiler_config is None:
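
As a quick illustration of the `register_custom_hooks` method added above: it accepts a single hook config dict, a list of such dicts, or an already-instantiated `Hook`. A sketch, assuming a `BaseRunner` instance `runner` and a hypothetical, already-registered `MyToyHook`:

```python
# Sketch only: 'runner' is any BaseRunner instance and 'MyToyHook' is a
# hypothetical Hook subclass already registered in the HOOKS registry.
runner.register_custom_hooks(dict(type='MyToyHook'))      # a single config dict
runner.register_custom_hooks([                             # a list of config dicts,
    dict(type='MyToyHook', priority=5),                    # each with its own priority
    dict(type='MyToyHook', priority=95),
])
runner.register_custom_hooks(MyToyHook())                  # a Hook instance, registered at 'NORMAL'
```
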
@@ -469,17 +482,20 @@ class BaseRunner(metaclass=ABCMeta):
checkpoint_config=None,
log_config=None,
momentum_config=None,
timer_config=dict(type='IterTimerHook')):
"""Register default hooks for training.
Default hooks include:
- LrUpdaterHook
- MomentumUpdaterHook
- OptimizerStepperHook
- CheckpointSaverHook
- IterTimerHook
- LoggerHook(s)
timer_config=dict(type='IterTimerHook'),
custom_hooks_config=None):
"""Register default and custom hooks for training.
Default and custom hooks include:
Hooks Priority
- LrUpdaterHook 10
- MomentumUpdaterHook 30
- OptimizerStepperHook 50
- CheckpointSaverHook 70
- IterTimerHook 80
- LoggerHook(s) 90
- CustomHook(s) 50 (default)
"""
self.register_lr_hook(lr_config)
self.register_momentum_hook(momentum_config)
......@@ -487,3 +503,4 @@ class BaseRunner(metaclass=ABCMeta):
self.register_checkpoint_hook(checkpoint_config)
self.register_timer_hook(timer_config)
self.register_logger_hooks(log_config)
self.register_custom_hooks(custom_hooks_config)
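
Putting this file's changes together, a runner set up through `register_training_hooks` now ends up with a deterministic hook order. A rough sketch of the extended call and the resulting ordering (the `cfg.*` values and `MyToyHook` are placeholders, not part of the commit):

```python
# Sketch of the extended call; the cfg.* values and 'MyToyHook' are placeholders.
runner.register_training_hooks(
    lr_config=cfg.lr_config,
    optimizer_config=cfg.optimizer_config,
    checkpoint_config=cfg.checkpoint_config,
    log_config=cfg.log_config,
    custom_hooks_config=[dict(type='MyToyHook', priority=95)])
# Resulting order in runner.hooks (lower priority value comes first):
#   lr(10) < optimizer(50) < checkpoint(70) < timer(80) < logger(90) < MyToyHook(95)
```
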
@@ -21,6 +21,7 @@ from torch.utils.data import DataLoader
from mmcv.runner import (CheckpointHook, EMAHook, IterTimerHook,
MlflowLoggerHook, PaviLoggerHook, WandbLoggerHook,
build_runner)
from mmcv.runner.hooks.hook import HOOKS, Hook
from mmcv.runner.hooks.lr_updater import (CosineRestartLrUpdaterHook,
CyclicLrUpdaterHook,
OneCycleLrUpdaterHook,
@@ -123,6 +124,53 @@ def test_ema_hook():
shutil.rmtree(work_dir)
def test_custom_hook():
@HOOKS.register_module()
class ToyHook(Hook):
def __init__(self, info, *args, **kwargs):
super().__init__()
self.info = info
runner = _build_demo_runner_without_hook('EpochBasedRunner', max_epochs=1)
# test if custom_hooks is None
runner.register_custom_hooks(None)
assert len(runner.hooks) == 0
# test if custom_hooks is dict list
custom_hooks_cfg = [
dict(type='ToyHook', priority=51, info=51),
dict(type='ToyHook', priority=49, info=49)
]
runner.register_custom_hooks(custom_hooks_cfg)
assert [hook.info for hook in runner.hooks] == [49, 51]
# test if custom_hooks is object and without priority
runner.register_custom_hooks(ToyHook(info='default'))
assert len(runner.hooks) == 3 and runner.hooks[1].info == 'default'
shutil.rmtree(runner.work_dir)
runner = _build_demo_runner_without_hook('EpochBasedRunner', max_epochs=1)
# test register_training_hooks order
custom_hooks_cfg = [
dict(type='ToyHook', priority=1, info='custom 1'),
dict(type='ToyHook', priority=89, info='custom 89')
]
runner.register_training_hooks(
lr_config=ToyHook('lr'),
optimizer_config=ToyHook('optimizer'),
checkpoint_config=ToyHook('checkpoint'),
log_config=dict(interval=1, hooks=[dict(type='ToyHook', info='log')]),
momentum_config=ToyHook('momentum'),
timer_config=ToyHook('timer'),
custom_hooks_config=custom_hooks_cfg)
hooks_order = [
'custom 1', 'lr', 'momentum', 'optimizer', 'checkpoint', 'timer',
'custom 89', 'log'
]
assert [hook.info for hook in runner.hooks] == hooks_order
shutil.rmtree(runner.work_dir)
def test_pavi_hook():
sys.modules['pavi'] = MagicMock()
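
For readers less familiar with the `Hook` interface that `ToyHook` above subclasses: a real custom hook typically overrides one or more runner callbacks. A minimal sketch, assuming the standard mmcv imports (the class name and logged message are made up):

```python
from mmcv.runner import HOOKS, Hook


@HOOKS.register_module()
class MyToyHook(Hook):
    """Hypothetical hook that logs a message after every training epoch."""

    def __init__(self, message='epoch finished'):
        self.message = message

    def after_train_epoch(self, runner):
        # Called once per training epoch; uses the runner's own logger.
        runner.logger.info(self.message)
```
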
@@ -867,7 +915,7 @@ def test_wandb_hook():
hook.wandb.join.assert_called_with()
def _build_demo_runner(runner_type='EpochBasedRunner',
def _build_demo_runner_without_hook(runner_type='EpochBasedRunner',
max_epochs=1,
max_iters=None,
multi_optimziers=False):
@@ -900,11 +948,6 @@ def _build_demo_runner(runner_type='EpochBasedRunner',
else:
optimizer = torch.optim.SGD(model.parameters(), lr=0.02, momentum=0.95)
log_config = dict(
interval=1, hooks=[
dict(type='TextLoggerHook'),
])
tmp_dir = tempfile.mkdtemp()
runner = build_runner(
dict(type=runner_type),
@@ -915,6 +958,22 @@ def _build_demo_runner(runner_type='EpochBasedRunner',
logger=logging.getLogger(),
max_epochs=max_epochs,
max_iters=max_iters))
return runner
def _build_demo_runner(runner_type='EpochBasedRunner',
max_epochs=1,
max_iters=None,
multi_optimziers=False):
log_config = dict(
interval=1, hooks=[
dict(type='TextLoggerHook'),
])
runner = _build_demo_runner_without_hook(runner_type, max_epochs,
max_iters, multi_optimziers)
runner.register_checkpoint_hook(dict(interval=1))
runner.register_logger_hooks(log_config)
return runner