Unverified Commit 66604e83 authored by Wang Xinjiang's avatar Wang Xinjiang Committed by GitHub
Browse files

Add syncbuffer hook (#443)

* reformat

* reformat

* Add register hook from cfg

* docstring

* change according to comments
parent 903091ef
...@@ -271,6 +271,22 @@ class BaseRunner(metaclass=ABCMeta): ...@@ -271,6 +271,22 @@ class BaseRunner(metaclass=ABCMeta):
if not inserted: if not inserted:
self._hooks.insert(0, hook) self._hooks.insert(0, hook)
def register_hook_from_cfg(self, hook_cfg):
"""Register a hook from its cfg.
Args:
hook_cfg (dict): Hook config. It should have at least keys 'type'
and 'priority' indicating its type and priority.
Notes:
The specific hook class to register should not use 'type' and
'priority' arguments during initialization.
"""
hook_cfg = hook_cfg.copy()
priority = hook_cfg.pop('priority', 'NORMAL')
hook = mmcv.build_from_cfg(hook_cfg, HOOKS)
self.register_hook(hook, priority=priority)
def call_hook(self, fn_name): def call_hook(self, fn_name):
"""Call all hooks. """Call all hooks.
......
# Copyright (c) Open-MMLab. All rights reserved.
import torch.distributed as dist
from .hook import HOOKS, Hook
@HOOKS.register_module()
class SyncBuffersHook(Hook):
"""Synchronize model buffers such as running_mean and running_var in BN at
the end of each epoch.
Args:
distributed (bool): Whether distributed training is used. It is
effective only for distributed training. Defaults to True.
"""
def __init__(self, distributed=True):
self.distributed = distributed
def after_epoch(self, runner):
"""All-reduce model buffers at the end of each epoch."""
if self.distributed:
buffers = runner.model.buffers()
world_size = dist.get_world_size()
for tensor in buffers:
dist.all_reduce(tensor.div_(world_size))
...@@ -18,11 +18,7 @@ from torch.utils.data import DataLoader ...@@ -18,11 +18,7 @@ from torch.utils.data import DataLoader
from mmcv.runner import (EpochBasedRunner, IterTimerHook, MlflowLoggerHook, from mmcv.runner import (EpochBasedRunner, IterTimerHook, MlflowLoggerHook,
PaviLoggerHook, WandbLoggerHook) PaviLoggerHook, WandbLoggerHook)
from mmcv.runner.hooks.lr_updater import (CosineAnnealingLrUpdaterHook, from mmcv.runner.hooks.lr_updater import CosineRestartLrUpdaterHook
CosineRestartLrUpdaterHook,
CyclicLrUpdaterHook)
from mmcv.runner.hooks.momentum_updater import (
CosineAnnealingMomentumUpdaterHook, CyclicMomentumUpdaterHook)
def test_pavi_hook(): def test_pavi_hook():
...@@ -53,21 +49,23 @@ def test_momentum_runner_hook(): ...@@ -53,21 +49,23 @@ def test_momentum_runner_hook():
runner = _build_demo_runner() runner = _build_demo_runner()
# add momentum scheduler # add momentum scheduler
hook = CyclicMomentumUpdaterHook( hook_cfg = dict(
type='CyclicMomentumUpdaterHook',
by_epoch=False, by_epoch=False,
target_ratio=(0.85 / 0.95, 1), target_ratio=(0.85 / 0.95, 1),
cyclic_times=1, cyclic_times=1,
step_ratio_up=0.4) step_ratio_up=0.4)
runner.register_hook(hook) runner.register_hook_from_cfg(hook_cfg)
# add momentum LR scheduler # add momentum LR scheduler
hook = CyclicLrUpdaterHook( hook_cfg = dict(
type='CyclicLrUpdaterHook',
by_epoch=False, by_epoch=False,
target_ratio=(10, 1), target_ratio=(10, 1),
cyclic_times=1, cyclic_times=1,
step_ratio_up=0.4) step_ratio_up=0.4)
runner.register_hook(hook) runner.register_hook_from_cfg(hook_cfg)
runner.register_hook(IterTimerHook()) runner.register_hook_from_cfg(dict(type='IterTimerHook'))
# add pavi hook # add pavi hook
hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True) hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
...@@ -101,19 +99,25 @@ def test_cosine_runner_hook(): ...@@ -101,19 +99,25 @@ def test_cosine_runner_hook():
runner = _build_demo_runner() runner = _build_demo_runner()
# add momentum scheduler # add momentum scheduler
hook = CosineAnnealingMomentumUpdaterHook(
hook_cfg = dict(
type='CosineAnnealingMomentumUpdaterHook',
min_momentum_ratio=0.99 / 0.95, min_momentum_ratio=0.99 / 0.95,
by_epoch=False, by_epoch=False,
warmup_iters=2, warmup_iters=2,
warmup_ratio=0.9 / 0.95) warmup_ratio=0.9 / 0.95)
runner.register_hook(hook) runner.register_hook_from_cfg(hook_cfg)
# add momentum LR scheduler # add momentum LR scheduler
hook = CosineAnnealingLrUpdaterHook( hook_cfg = dict(
by_epoch=False, min_lr_ratio=0, warmup_iters=2, warmup_ratio=0.9) type='CosineAnnealingLrUpdaterHook',
runner.register_hook(hook) by_epoch=False,
min_lr_ratio=0,
warmup_iters=2,
warmup_ratio=0.9)
runner.register_hook_from_cfg(hook_cfg)
runner.register_hook_from_cfg(dict(type='IterTimerHook'))
runner.register_hook(IterTimerHook()) runner.register_hook(IterTimerHook())
# add pavi hook # add pavi hook
hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True) hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
runner.register_hook(hook) runner.register_hook(hook)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment