Unverified Commit 6fe37225 authored by Ma Zerun's avatar Ma Zerun Committed by GitHub
Browse files

Refine default hooks and custom hooks priority rank. (#1120)

* Refine default hooks and custom hooks priority rank.

* Add unit tests for custom hooks with string priority.

* Use priority `ABOVE_NORMAL` and `BELOW_NORMAL` instead of `HIGHER` and
`LOWER`.

And add unit tests for custom hook with the same priority as
default hooks.
parent d9effbd1
......@@ -394,7 +394,7 @@ class BaseRunner(metaclass=ABCMeta):
hook = mmcv.build_from_cfg(lr_config, HOOKS)
else:
hook = lr_config
self.register_hook(hook, priority=10)
self.register_hook(hook, priority='VERY_HIGH')
def register_momentum_hook(self, momentum_config):
if momentum_config is None:
......@@ -415,7 +415,7 @@ class BaseRunner(metaclass=ABCMeta):
hook = mmcv.build_from_cfg(momentum_config, HOOKS)
else:
hook = momentum_config
self.register_hook(hook, priority=30)
self.register_hook(hook, priority='HIGH')
def register_optimizer_hook(self, optimizer_config):
if optimizer_config is None:
......@@ -425,7 +425,7 @@ class BaseRunner(metaclass=ABCMeta):
hook = mmcv.build_from_cfg(optimizer_config, HOOKS)
else:
hook = optimizer_config
self.register_hook(hook, priority=50)
self.register_hook(hook, priority='ABOVE_NORMAL')
def register_checkpoint_hook(self, checkpoint_config):
if checkpoint_config is None:
......@@ -435,7 +435,7 @@ class BaseRunner(metaclass=ABCMeta):
hook = mmcv.build_from_cfg(checkpoint_config, HOOKS)
else:
hook = checkpoint_config
self.register_hook(hook, priority=70)
self.register_hook(hook, priority='NORMAL')
def register_logger_hooks(self, log_config):
if log_config is None:
......@@ -444,7 +444,7 @@ class BaseRunner(metaclass=ABCMeta):
for info in log_config['hooks']:
logger_hook = mmcv.build_from_cfg(
info, HOOKS, default_args=dict(interval=log_interval))
self.register_hook(logger_hook, priority=90)
self.register_hook(logger_hook, priority='VERY_LOW')
def register_timer_hook(self, timer_config):
if timer_config is None:
......@@ -454,7 +454,7 @@ class BaseRunner(metaclass=ABCMeta):
hook = mmcv.build_from_cfg(timer_config_, HOOKS)
else:
hook = timer_config
self.register_hook(hook, priority=80)
self.register_hook(hook, priority='LOW')
def register_custom_hooks(self, custom_config):
if custom_config is None:
......@@ -491,14 +491,26 @@ class BaseRunner(metaclass=ABCMeta):
Default and custom hooks include:
Hooks Priority
- LrUpdaterHook 10
- MomentumUpdaterHook 30
- OptimizerStepperHook 50
- CheckpointSaverHook 70
- IterTimerHook 80
- LoggerHook(s) 90
- CustomHook(s) 50 (default)
+----------------------+-------------------------+
| Hooks | Priority |
+======================+=========================+
| LrUpdaterHook | VERY_HIGH (10) |
+----------------------+-------------------------+
| MomentumUpdaterHook | HIGH (30) |
+----------------------+-------------------------+
| OptimizerStepperHook | ABOVE_NORMAL (40) |
+----------------------+-------------------------+
| CheckpointSaverHook | NORMAL (50) |
+----------------------+-------------------------+
| IterTimerHook | LOW (70) |
+----------------------+-------------------------+
| LoggerHook(s) | VERY_LOW (90) |
+----------------------+-------------------------+
| CustomHook(s) | defaults to NORMAL (50) |
+----------------------+-------------------------+
If custom hooks have same priority with default hooks, custom hooks
will be triggered after default hooks.
"""
self.register_lr_hook(lr_config)
self.register_momentum_hook(momentum_config)
......
......@@ -5,29 +5,35 @@ from enum import Enum
class Priority(Enum):
"""Hook priority levels.
+------------+------------+
| Level | Value |
+============+============+
| HIGHEST | 0 |
+------------+------------+
| VERY_HIGH | 10 |
+------------+------------+
| HIGH | 30 |
+------------+------------+
| NORMAL | 50 |
+------------+------------+
| LOW | 70 |
+------------+------------+
| VERY_LOW | 90 |
+------------+------------+
| LOWEST | 100 |
+------------+------------+
+--------------+------------+
| Level | Value |
+==============+============+
| HIGHEST | 0 |
+--------------+------------+
| VERY_HIGH | 10 |
+--------------+------------+
| HIGH | 30 |
+--------------+------------+
| ABOVE_NORMAL | 40 |
+--------------+------------+
| NORMAL | 50 |
+--------------+------------+
| BELOW_NORMAL | 60 |
+--------------+------------+
| LOW | 70 |
+--------------+------------+
| VERY_LOW | 90 |
+--------------+------------+
| LOWEST | 100 |
+--------------+------------+
"""
HIGHEST = 0
VERY_HIGH = 10
HIGH = 30
ABOVE_NORMAL = 40
NORMAL = 50
BELOW_NORMAL = 60
LOW = 70
VERY_LOW = 90
LOWEST = 100
......
......@@ -6,6 +6,7 @@ CommandLine:
"""
import logging
import os.path as osp
import random
import re
import shutil
import sys
......@@ -149,10 +150,27 @@ def test_custom_hook():
assert len(runner.hooks) == 3 and runner.hooks[1].info == 'default'
shutil.rmtree(runner.work_dir)
runner = _build_demo_runner_without_hook('EpochBasedRunner', max_epochs=1)
# test custom_hooks with string priority setting
priority_ranks = [
'HIGHEST', 'VERY_HIGH', 'HIGH', 'ABOVE_NORMAL', 'NORMAL',
'BELOW_NORMAL', 'LOW', 'VERY_LOW', 'LOWEST'
]
random_priority_ranks = priority_ranks.copy()
random.shuffle(random_priority_ranks)
custom_hooks_cfg = [
dict(type='ToyHook', priority=rank, info=rank)
for rank in random_priority_ranks
]
runner.register_custom_hooks(custom_hooks_cfg)
assert [hook.info for hook in runner.hooks] == priority_ranks
shutil.rmtree(runner.work_dir)
runner = _build_demo_runner_without_hook('EpochBasedRunner', max_epochs=1)
# test register_training_hooks order
custom_hooks_cfg = [
dict(type='ToyHook', priority=1, info='custom 1'),
dict(type='ToyHook', priority='NORMAL', info='custom normal'),
dict(type='ToyHook', priority=89, info='custom 89')
]
runner.register_training_hooks(
......@@ -163,9 +181,11 @@ def test_custom_hook():
momentum_config=ToyHook('momentum'),
timer_config=ToyHook('timer'),
custom_hooks_config=custom_hooks_cfg)
# If custom hooks have same priority with default hooks, custom hooks
# will be triggered after default hooks.
hooks_order = [
'custom 1', 'lr', 'momentum', 'optimizer', 'checkpoint', 'timer',
'custom 89', 'log'
'custom 1', 'lr', 'momentum', 'optimizer', 'checkpoint',
'custom normal', 'timer', 'custom 89', 'log'
]
assert [hook.info for hook in runner.hooks] == hooks_order
shutil.rmtree(runner.work_dir)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment