Unverified Commit 6fe37225 authored by Ma Zerun's avatar Ma Zerun Committed by GitHub
Browse files

Refine default hooks and custom hooks priority rank. (#1120)

* Refine default hooks and custom hooks priority rank.

* Add unit tests for custom hooks with string priority.

* Use priority `ABOVE_NORMAL` and `BELOW_NORMAL` instead of `HIGHER` and
`LOWER`.

And add unit tests for custom hook with the same priority as
default hooks.
parent d9effbd1
...@@ -394,7 +394,7 @@ class BaseRunner(metaclass=ABCMeta): ...@@ -394,7 +394,7 @@ class BaseRunner(metaclass=ABCMeta):
hook = mmcv.build_from_cfg(lr_config, HOOKS) hook = mmcv.build_from_cfg(lr_config, HOOKS)
else: else:
hook = lr_config hook = lr_config
self.register_hook(hook, priority=10) self.register_hook(hook, priority='VERY_HIGH')
def register_momentum_hook(self, momentum_config): def register_momentum_hook(self, momentum_config):
if momentum_config is None: if momentum_config is None:
...@@ -415,7 +415,7 @@ class BaseRunner(metaclass=ABCMeta): ...@@ -415,7 +415,7 @@ class BaseRunner(metaclass=ABCMeta):
hook = mmcv.build_from_cfg(momentum_config, HOOKS) hook = mmcv.build_from_cfg(momentum_config, HOOKS)
else: else:
hook = momentum_config hook = momentum_config
self.register_hook(hook, priority=30) self.register_hook(hook, priority='HIGH')
def register_optimizer_hook(self, optimizer_config): def register_optimizer_hook(self, optimizer_config):
if optimizer_config is None: if optimizer_config is None:
...@@ -425,7 +425,7 @@ class BaseRunner(metaclass=ABCMeta): ...@@ -425,7 +425,7 @@ class BaseRunner(metaclass=ABCMeta):
hook = mmcv.build_from_cfg(optimizer_config, HOOKS) hook = mmcv.build_from_cfg(optimizer_config, HOOKS)
else: else:
hook = optimizer_config hook = optimizer_config
self.register_hook(hook, priority=50) self.register_hook(hook, priority='ABOVE_NORMAL')
def register_checkpoint_hook(self, checkpoint_config): def register_checkpoint_hook(self, checkpoint_config):
if checkpoint_config is None: if checkpoint_config is None:
...@@ -435,7 +435,7 @@ class BaseRunner(metaclass=ABCMeta): ...@@ -435,7 +435,7 @@ class BaseRunner(metaclass=ABCMeta):
hook = mmcv.build_from_cfg(checkpoint_config, HOOKS) hook = mmcv.build_from_cfg(checkpoint_config, HOOKS)
else: else:
hook = checkpoint_config hook = checkpoint_config
self.register_hook(hook, priority=70) self.register_hook(hook, priority='NORMAL')
def register_logger_hooks(self, log_config): def register_logger_hooks(self, log_config):
if log_config is None: if log_config is None:
...@@ -444,7 +444,7 @@ class BaseRunner(metaclass=ABCMeta): ...@@ -444,7 +444,7 @@ class BaseRunner(metaclass=ABCMeta):
for info in log_config['hooks']: for info in log_config['hooks']:
logger_hook = mmcv.build_from_cfg( logger_hook = mmcv.build_from_cfg(
info, HOOKS, default_args=dict(interval=log_interval)) info, HOOKS, default_args=dict(interval=log_interval))
self.register_hook(logger_hook, priority=90) self.register_hook(logger_hook, priority='VERY_LOW')
def register_timer_hook(self, timer_config): def register_timer_hook(self, timer_config):
if timer_config is None: if timer_config is None:
...@@ -454,7 +454,7 @@ class BaseRunner(metaclass=ABCMeta): ...@@ -454,7 +454,7 @@ class BaseRunner(metaclass=ABCMeta):
hook = mmcv.build_from_cfg(timer_config_, HOOKS) hook = mmcv.build_from_cfg(timer_config_, HOOKS)
else: else:
hook = timer_config hook = timer_config
self.register_hook(hook, priority=80) self.register_hook(hook, priority='LOW')
def register_custom_hooks(self, custom_config): def register_custom_hooks(self, custom_config):
if custom_config is None: if custom_config is None:
...@@ -491,14 +491,26 @@ class BaseRunner(metaclass=ABCMeta): ...@@ -491,14 +491,26 @@ class BaseRunner(metaclass=ABCMeta):
Default and custom hooks include: Default and custom hooks include:
Hooks Priority +----------------------+-------------------------+
- LrUpdaterHook 10 | Hooks | Priority |
- MomentumUpdaterHook 30 +======================+=========================+
- OptimizerStepperHook 50 | LrUpdaterHook | VERY_HIGH (10) |
- CheckpointSaverHook 70 +----------------------+-------------------------+
- IterTimerHook 80 | MomentumUpdaterHook | HIGH (30) |
- LoggerHook(s) 90 +----------------------+-------------------------+
- CustomHook(s) 50 (default) | OptimizerStepperHook | ABOVE_NORMAL (40) |
+----------------------+-------------------------+
| CheckpointSaverHook | NORMAL (50) |
+----------------------+-------------------------+
| IterTimerHook | LOW (70) |
+----------------------+-------------------------+
| LoggerHook(s) | VERY_LOW (90) |
+----------------------+-------------------------+
| CustomHook(s) | defaults to NORMAL (50) |
+----------------------+-------------------------+
If custom hooks have same priority with default hooks, custom hooks
will be triggered after default hooks.
""" """
self.register_lr_hook(lr_config) self.register_lr_hook(lr_config)
self.register_momentum_hook(momentum_config) self.register_momentum_hook(momentum_config)
......
...@@ -5,29 +5,35 @@ from enum import Enum ...@@ -5,29 +5,35 @@ from enum import Enum
class Priority(Enum): class Priority(Enum):
"""Hook priority levels. """Hook priority levels.
+------------+------------+ +--------------+------------+
| Level | Value | | Level | Value |
+============+============+ +==============+============+
| HIGHEST | 0 | | HIGHEST | 0 |
+------------+------------+ +--------------+------------+
| VERY_HIGH | 10 | | VERY_HIGH | 10 |
+------------+------------+ +--------------+------------+
| HIGH | 30 | | HIGH | 30 |
+------------+------------+ +--------------+------------+
| NORMAL | 50 | | ABOVE_NORMAL | 40 |
+------------+------------+ +--------------+------------+
| LOW | 70 | | NORMAL | 50 |
+------------+------------+ +--------------+------------+
| VERY_LOW | 90 | | BELOW_NORMAL | 60 |
+------------+------------+ +--------------+------------+
| LOWEST | 100 | | LOW | 70 |
+------------+------------+ +--------------+------------+
| VERY_LOW | 90 |
+--------------+------------+
| LOWEST | 100 |
+--------------+------------+
""" """
HIGHEST = 0 HIGHEST = 0
VERY_HIGH = 10 VERY_HIGH = 10
HIGH = 30 HIGH = 30
ABOVE_NORMAL = 40
NORMAL = 50 NORMAL = 50
BELOW_NORMAL = 60
LOW = 70 LOW = 70
VERY_LOW = 90 VERY_LOW = 90
LOWEST = 100 LOWEST = 100
......
...@@ -6,6 +6,7 @@ CommandLine: ...@@ -6,6 +6,7 @@ CommandLine:
""" """
import logging import logging
import os.path as osp import os.path as osp
import random
import re import re
import shutil import shutil
import sys import sys
...@@ -149,10 +150,27 @@ def test_custom_hook(): ...@@ -149,10 +150,27 @@ def test_custom_hook():
assert len(runner.hooks) == 3 and runner.hooks[1].info == 'default' assert len(runner.hooks) == 3 and runner.hooks[1].info == 'default'
shutil.rmtree(runner.work_dir) shutil.rmtree(runner.work_dir)
runner = _build_demo_runner_without_hook('EpochBasedRunner', max_epochs=1)
# test custom_hooks with string priority setting
priority_ranks = [
'HIGHEST', 'VERY_HIGH', 'HIGH', 'ABOVE_NORMAL', 'NORMAL',
'BELOW_NORMAL', 'LOW', 'VERY_LOW', 'LOWEST'
]
random_priority_ranks = priority_ranks.copy()
random.shuffle(random_priority_ranks)
custom_hooks_cfg = [
dict(type='ToyHook', priority=rank, info=rank)
for rank in random_priority_ranks
]
runner.register_custom_hooks(custom_hooks_cfg)
assert [hook.info for hook in runner.hooks] == priority_ranks
shutil.rmtree(runner.work_dir)
runner = _build_demo_runner_without_hook('EpochBasedRunner', max_epochs=1) runner = _build_demo_runner_without_hook('EpochBasedRunner', max_epochs=1)
# test register_training_hooks order # test register_training_hooks order
custom_hooks_cfg = [ custom_hooks_cfg = [
dict(type='ToyHook', priority=1, info='custom 1'), dict(type='ToyHook', priority=1, info='custom 1'),
dict(type='ToyHook', priority='NORMAL', info='custom normal'),
dict(type='ToyHook', priority=89, info='custom 89') dict(type='ToyHook', priority=89, info='custom 89')
] ]
runner.register_training_hooks( runner.register_training_hooks(
...@@ -163,9 +181,11 @@ def test_custom_hook(): ...@@ -163,9 +181,11 @@ def test_custom_hook():
momentum_config=ToyHook('momentum'), momentum_config=ToyHook('momentum'),
timer_config=ToyHook('timer'), timer_config=ToyHook('timer'),
custom_hooks_config=custom_hooks_cfg) custom_hooks_config=custom_hooks_cfg)
# If custom hooks have same priority with default hooks, custom hooks
# will be triggered after default hooks.
hooks_order = [ hooks_order = [
'custom 1', 'lr', 'momentum', 'optimizer', 'checkpoint', 'timer', 'custom 1', 'lr', 'momentum', 'optimizer', 'checkpoint',
'custom 89', 'log' 'custom normal', 'timer', 'custom 89', 'log'
] ]
assert [hook.info for hook in runner.hooks] == hooks_order assert [hook.info for hook in runner.hooks] == hooks_order
shutil.rmtree(runner.work_dir) shutil.rmtree(runner.work_dir)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment