Unverified Commit d4da3daa authored by Wang Xinjiang's avatar Wang Xinjiang Committed by GitHub
Browse files

Syncbuf (#447)

* More robust sync buffer hook

* More robust sync buffer hook

* Reformat
parent cec5aace
......@@ -10,11 +10,12 @@ from .memory import EmptyCacheHook
from .momentum_updater import MomentumUpdaterHook
from .optimizer import Fp16OptimizerHook, OptimizerHook
from .sampler_seed import DistSamplerSeedHook
from .sync_buffer import SyncBuffersHook
__all__ = [
'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook',
'OptimizerHook', 'Fp16OptimizerHook', 'IterTimerHook',
'DistSamplerSeedHook', 'EmptyCacheHook', 'LoggerHook', 'MlflowLoggerHook',
'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
'WandbLoggerHook', 'MomentumUpdaterHook'
'WandbLoggerHook', 'MomentumUpdaterHook', 'SyncBuffersHook'
]
# Copyright (c) Open-MMLab. All rights reserved.
import torch.distributed as dist
from ..dist_utils import get_dist_info
from .hook import HOOKS, Hook
......@@ -19,7 +20,8 @@ class SyncBuffersHook(Hook):
def after_epoch(self, runner):
"""All-reduce model buffers at the end of each epoch."""
if self.distributed:
_, world_size = get_dist_info()
if self.distributed and world_size > 1:
buffers = runner.model.buffers()
world_size = dist.get_world_size()
for tensor in buffers:
......
......@@ -42,6 +42,14 @@ def test_pavi_hook():
iteration=5)
def test_sync_buffers_hook():
loader = DataLoader(torch.ones((5, 2)))
runner = _build_demo_runner()
runner.register_hook_from_cfg(dict(type='SyncBuffersHook'))
runner.run([loader, loader], [('train', 1), ('val', 1)], 1)
shutil.rmtree(runner.work_dir)
def test_momentum_runner_hook():
"""xdoctest -m tests/test_hooks.py test_momentum_runner_hook."""
sys.modules['pavi'] = MagicMock()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment