Unverified Commit 1b15f022 authored by Ma Zerun's avatar Ma Zerun Committed by GitHub
Browse files

support print hooks before running. (#1123)

* support print using hooks before running.

* Support to print hook trigger stages.

* Print stage-wise hook infos. And make `stages` as class attribute of
`Hook`

* Add util function `is_method_overriden` and use it in
`Hook.get_trigger_stages`.

* Add unit tests.

* Move `is_method_overriden` to `mmcv/utils/misc.py`

* Improve hook info text.

* Add base_class argument type assertion, and fix some typos.

* Remove `get_trigger_stages` to `get_triggered_stages`

* Use f-string.
parent 227e7a73
......@@ -14,7 +14,7 @@ from .checkpoint import load_checkpoint
from .dist_utils import get_dist_info
from .hooks import HOOKS, Hook
from .log_buffer import LogBuffer
from .priority import get_priority
from .priority import Priority, get_priority
from .utils import get_time_str
......@@ -306,6 +306,29 @@ class BaseRunner(metaclass=ABCMeta):
for hook in self._hooks:
getattr(hook, fn_name)(self)
def get_hook_info(self):
# Get hooks info in each stage
stage_hook_map = {stage: [] for stage in Hook.stages}
for hook in self.hooks:
try:
priority = Priority(hook.priority).name
except ValueError:
priority = hook.priority
classname = hook.__class__.__name__
hook_info = f'({priority:<12}) {classname:<35}'
for trigger_stage in hook.get_triggered_stages():
stage_hook_map[trigger_stage].append(hook_info)
stage_hook_infos = []
for stage in Hook.stages:
hook_infos = stage_hook_map[stage]
if len(hook_infos) > 0:
info = f'{stage}:\n'
info += '\n'.join(hook_infos)
info += '\n -------------------- '
stage_hook_infos.append(info)
return '\n'.join(stage_hook_infos)
def load_checkpoint(self,
filename,
map_location='cpu',
......
......@@ -101,6 +101,8 @@ class EpochBasedRunner(BaseRunner):
work_dir = self.work_dir if self.work_dir is not None else 'NONE'
self.logger.info('Start running, host: %s, work_dir: %s',
get_host_info(), work_dir)
self.logger.info('Hooks will be executed in the following order:\n%s',
self.get_hook_info())
self.logger.info('workflow: %s, max: %d epochs', workflow,
self._max_epochs)
self.call_hook('before_run')
......
# Copyright (c) Open-MMLab. All rights reserved.
from mmcv.utils import Registry
from mmcv.utils import Registry, is_method_overridden
HOOKS = Registry('hook')
class Hook:
stages = ('before_run', 'before_train_epoch', 'before_train_iter',
'after_train_iter', 'after_train_epoch', 'before_val_epoch',
'before_val_iter', 'after_val_iter', 'after_val_epoch',
'after_run')
def before_run(self, runner):
pass
......@@ -65,3 +69,24 @@ class Hook:
def is_last_iter(self, runner):
return runner.iter + 1 == runner._max_iters
def get_triggered_stages(self):
trigger_stages = set()
for stage in Hook.stages:
if is_method_overridden(stage, Hook, self):
trigger_stages.add(stage)
# some methods will be triggered in multi stages
# use this dict to map method to stages.
method_stages_map = {
'before_epoch': ['before_train_epoch', 'before_val_epoch'],
'after_epoch': ['after_train_epoch', 'after_val_epoch'],
'before_iter': ['before_train_iter', 'before_val_iter'],
'after_iter': ['after_train_iter', 'after_val_iter'],
}
for method, map_stages in method_stages_map.items():
if is_method_overridden(method, Hook, self):
trigger_stages.update(map_stages)
return [stage for stage in Hook.stages if stage in trigger_stages]
......@@ -108,6 +108,8 @@ class IterBasedRunner(BaseRunner):
work_dir = self.work_dir if self.work_dir is not None else 'NONE'
self.logger.info('Start running, host: %s, work_dir: %s',
get_host_info(), work_dir)
self.logger.info('Hooks will be executed in the following order:\n%s',
self.get_hook_info())
self.logger.info('workflow: %s, max: %d iters', workflow,
self._max_iters)
self.call_hook('before_run')
......
......@@ -2,10 +2,11 @@
# Copyright (c) Open-MMLab. All rights reserved.
from .config import Config, ConfigDict, DictAction
from .misc import (check_prerequisites, concat_list, deprecated_api_warning,
import_modules_from_strings, is_list_of, is_seq_of, is_str,
is_tuple_of, iter_cast, list_cast, requires_executable,
requires_package, slice_list, to_1tuple, to_2tuple,
to_3tuple, to_4tuple, to_ntuple, tuple_cast)
import_modules_from_strings, is_list_of,
is_method_overridden, is_seq_of, is_str, is_tuple_of,
iter_cast, list_cast, requires_executable, requires_package,
slice_list, to_1tuple, to_2tuple, to_3tuple, to_4tuple,
to_ntuple, tuple_cast)
from .path import (check_file_exist, fopen, is_filepath, mkdir_or_exist,
scandir, symlink)
from .progressbar import (ProgressBar, track_iter_progress,
......@@ -31,7 +32,8 @@ except ImportError:
'digit_version', 'get_git_hash', 'import_modules_from_strings',
'assert_dict_contains_subset', 'assert_attrs_equal',
'assert_dict_has_keys', 'assert_keys_equal', 'check_python_script',
'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple'
'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple',
'is_method_overridden'
]
else:
from .env import collect_env
......@@ -60,5 +62,6 @@ else:
'get_git_hash', 'import_modules_from_strings', 'jit', 'skip_no_elena',
'assert_dict_contains_subset', 'assert_attrs_equal',
'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer',
'assert_params_all_zeros', 'check_python_script'
'assert_params_all_zeros', 'check_python_script',
'is_method_overridden'
]
......@@ -333,3 +333,22 @@ def deprecated_api_warning(name_dict, cls_name=None):
return new_func
return api_warning_wrapper
def is_method_overridden(method, base_class, derived_class):
"""Check if a method of base class is overridden in derived class.
Args:
method (str): the method name to check.
base_class (type): the class of the base class.
derived_class (type | Any): the class or instance of the derived class.
"""
assert isinstance(base_class, type), \
"base_class doesn't accept instance, Please pass class instead."
if not isinstance(derived_class, type):
derived_class = derived_class.__class__
base_method = getattr(base_class, method)
derived_method = getattr(derived_class, method)
return derived_method != base_method
......@@ -1070,3 +1070,20 @@ def test_runner_with_revise_keys():
key_stripped = re.sub(r'^backbone\.', '', key)
assert torch.equal(model.state_dict()[key_stripped], state_dict[key])
os.remove(checkpoint_path)
def test_get_triggered_stages():
class ToyHook(Hook):
# test normal stage
def before_run():
pass
# test the method mapped to multi stages.
def after_epoch():
pass
hook = ToyHook()
# stages output have order, so here is list instead of set.
expected_stages = ['before_run', 'after_train_epoch', 'after_val_epoch']
assert hook.get_triggered_stages() == expected_stages
......@@ -160,3 +160,33 @@ def test_import_modules_from_strings():
['os.path', '_not_implemented'], allow_failed_imports=True)
assert imported[0] == osp
assert imported[1] is None
def test_is_method_overridden():
class Base:
def foo1():
pass
def foo2():
pass
class Sub(Base):
def foo1():
pass
# test passing sub class directly
assert mmcv.is_method_overridden('foo1', Base, Sub)
assert not mmcv.is_method_overridden('foo2', Base, Sub)
# test passing instance of sub class
sub_instance = Sub()
assert mmcv.is_method_overridden('foo1', Base, sub_instance)
assert not mmcv.is_method_overridden('foo2', Base, sub_instance)
# base_class should be a class, not instance
base_instance = Base()
with pytest.raises(AssertionError):
mmcv.is_method_overridden('foo1', base_instance, sub_instance)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment