Unverified Commit 72e13e6a authored by Miao Zheng's avatar Miao Zheng Committed by GitHub
Browse files

fix baserunner bug (#840)

* fix baserunner bug

* add unit test

* hooks type check
parent 7fa78e7a
...@@ -420,7 +420,7 @@ class BaseRunner(metaclass=ABCMeta): ...@@ -420,7 +420,7 @@ class BaseRunner(metaclass=ABCMeta):
return return
if isinstance(timer_config, dict): if isinstance(timer_config, dict):
timer_config_ = copy.deepcopy(timer_config) timer_config_ = copy.deepcopy(timer_config)
hook = mmcv.buid_from_cfg(timer_config_, HOOKS) hook = mmcv.build_from_cfg(timer_config_, HOOKS)
else: else:
hook = timer_config hook = timer_config
self.register_hook(hook) self.register_hook(hook)
......
...@@ -13,6 +13,7 @@ import torch.nn as nn ...@@ -13,6 +13,7 @@ import torch.nn as nn
from mmcv.parallel import MMDataParallel from mmcv.parallel import MMDataParallel
from mmcv.runner import (RUNNERS, EpochBasedRunner, IterBasedRunner, from mmcv.runner import (RUNNERS, EpochBasedRunner, IterBasedRunner,
build_runner) build_runner)
from mmcv.runner.hooks import IterTimerHook
class OldStyleModel(nn.Module): class OldStyleModel(nn.Module):
...@@ -257,3 +258,26 @@ def test_build_lr_momentum_hook(runner_class): ...@@ -257,3 +258,26 @@ def test_build_lr_momentum_hook(runner_class):
step_ratio_up=0.4) step_ratio_up=0.4)
runner.register_momentum_hook(mom_config) runner.register_momentum_hook(mom_config)
assert len(runner.hooks) == 8 assert len(runner.hooks) == 8
@pytest.mark.parametrize('runner_class', RUNNERS.module_dict.values())
def test_register_timer_hook(runner_class):
model = Model()
runner = runner_class(model=model, logger=logging.getLogger())
# test register None
timer_config = None
runner.register_timer_hook(timer_config)
assert len(runner.hooks) == 0
# test register IterTimerHook with config
timer_config = dict(type='IterTimerHook')
runner.register_timer_hook(timer_config)
assert len(runner.hooks) == 1
assert isinstance(runner.hooks[0], IterTimerHook)
# test register IterTimerHook
timer_config = IterTimerHook()
runner.register_timer_hook(timer_config)
assert len(runner.hooks) == 2
assert isinstance(runner.hooks[1], IterTimerHook)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment