Unverified Commit d4da3daa authored by Wang Xinjiang's avatar Wang Xinjiang Committed by GitHub
Browse files

Syncbuf (#447)

* More robust sync buffer hook

* More robust sync buffer hook

* Reformat
parent cec5aace
...@@ -10,11 +10,12 @@ from .memory import EmptyCacheHook ...@@ -10,11 +10,12 @@ from .memory import EmptyCacheHook
from .momentum_updater import MomentumUpdaterHook from .momentum_updater import MomentumUpdaterHook
from .optimizer import Fp16OptimizerHook, OptimizerHook from .optimizer import Fp16OptimizerHook, OptimizerHook
from .sampler_seed import DistSamplerSeedHook from .sampler_seed import DistSamplerSeedHook
from .sync_buffer import SyncBuffersHook
__all__ = [ __all__ = [
'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook', 'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook',
'OptimizerHook', 'Fp16OptimizerHook', 'IterTimerHook', 'OptimizerHook', 'Fp16OptimizerHook', 'IterTimerHook',
'DistSamplerSeedHook', 'EmptyCacheHook', 'LoggerHook', 'MlflowLoggerHook', 'DistSamplerSeedHook', 'EmptyCacheHook', 'LoggerHook', 'MlflowLoggerHook',
'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook', 'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
'WandbLoggerHook', 'MomentumUpdaterHook' 'WandbLoggerHook', 'MomentumUpdaterHook', 'SyncBuffersHook'
] ]
# Copyright (c) Open-MMLab. All rights reserved. # Copyright (c) Open-MMLab. All rights reserved.
import torch.distributed as dist import torch.distributed as dist
from ..dist_utils import get_dist_info
from .hook import HOOKS, Hook from .hook import HOOKS, Hook
...@@ -19,7 +20,8 @@ class SyncBuffersHook(Hook): ...@@ -19,7 +20,8 @@ class SyncBuffersHook(Hook):
def after_epoch(self, runner): def after_epoch(self, runner):
"""All-reduce model buffers at the end of each epoch.""" """All-reduce model buffers at the end of each epoch."""
if self.distributed: _, world_size = get_dist_info()
if self.distributed and world_size > 1:
buffers = runner.model.buffers() buffers = runner.model.buffers()
world_size = dist.get_world_size() world_size = dist.get_world_size()
for tensor in buffers: for tensor in buffers:
......
...@@ -42,6 +42,14 @@ def test_pavi_hook(): ...@@ -42,6 +42,14 @@ def test_pavi_hook():
iteration=5) iteration=5)
def test_sync_buffers_hook():
    """Smoke-test SyncBuffersHook: registering it and running one
    train/val epoch should complete without raising."""
    # Tiny constant dataset keeps the run fast; content is irrelevant here.
    data_loader = DataLoader(torch.ones((5, 2)))
    runner = _build_demo_runner()
    # Attach the hook via config, mirroring how users enable it.
    runner.register_hook_from_cfg(dict(type='SyncBuffersHook'))
    runner.run([data_loader, data_loader], [('train', 1), ('val', 1)], 1)
    # Clean up the work dir the demo runner created.
    shutil.rmtree(runner.work_dir)
def test_momentum_runner_hook(): def test_momentum_runner_hook():
"""xdoctest -m tests/test_hooks.py test_momentum_runner_hook.""" """xdoctest -m tests/test_hooks.py test_momentum_runner_hook."""
sys.modules['pavi'] = MagicMock() sys.modules['pavi'] = MagicMock()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment