Unverified Commit 7fa78e7a authored by Miao Zheng's avatar Miao Zheng Committed by GitHub
Browse files

add register_itertimer_hook function (#838)

* add register_itertimer_hook function

* set default value

* revise minors

* revise according to comments

* fix according to comments

* update

* update
parent f75a88c2
# Copyright (c) Open-MMLab. All rights reserved. # Copyright (c) Open-MMLab. All rights reserved.
import copy
import logging import logging
import os.path as osp import os.path as osp
import warnings import warnings
...@@ -11,7 +12,7 @@ import mmcv ...@@ -11,7 +12,7 @@ import mmcv
from ..parallel import is_module_wrapper from ..parallel import is_module_wrapper
from .checkpoint import load_checkpoint from .checkpoint import load_checkpoint
from .dist_utils import get_dist_info from .dist_utils import get_dist_info
from .hooks import HOOKS, Hook, IterTimerHook from .hooks import HOOKS, Hook
from .log_buffer import LogBuffer from .log_buffer import LogBuffer
from .priority import get_priority from .priority import get_priority
from .utils import get_time_str from .utils import get_time_str
...@@ -414,12 +415,23 @@ class BaseRunner(metaclass=ABCMeta): ...@@ -414,12 +415,23 @@ class BaseRunner(metaclass=ABCMeta):
info, HOOKS, default_args=dict(interval=log_interval)) info, HOOKS, default_args=dict(interval=log_interval))
self.register_hook(logger_hook, priority='VERY_LOW') self.register_hook(logger_hook, priority='VERY_LOW')
def register_timer_hook(self, timer_config):
if timer_config is None:
return
if isinstance(timer_config, dict):
timer_config_ = copy.deepcopy(timer_config)
hook = mmcv.buid_from_cfg(timer_config_, HOOKS)
else:
hook = timer_config
self.register_hook(hook)
def register_training_hooks(self, def register_training_hooks(self,
lr_config, lr_config,
optimizer_config=None, optimizer_config=None,
checkpoint_config=None, checkpoint_config=None,
log_config=None, log_config=None,
momentum_config=None): momentum_config=None,
timer_config=dict(type='IterTimerHook')):
"""Register default hooks for training. """Register default hooks for training.
Default hooks include: Default hooks include:
...@@ -435,5 +447,5 @@ class BaseRunner(metaclass=ABCMeta): ...@@ -435,5 +447,5 @@ class BaseRunner(metaclass=ABCMeta):
self.register_momentum_hook(momentum_config) self.register_momentum_hook(momentum_config)
self.register_optimizer_hook(optimizer_config) self.register_optimizer_hook(optimizer_config)
self.register_checkpoint_hook(checkpoint_config) self.register_checkpoint_hook(checkpoint_config)
self.register_hook(IterTimerHook()) self.register_timer_hook(timer_config)
self.register_logger_hooks(log_config) self.register_logger_hooks(log_config)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment