Unverified commit 50a33950 authored by Wang Xinjiang, committed by GitHub

Add Sync buffer in CheckpointHook (#588)

* Add Sync buffer in CheckpointHook

* add reduce_params in fp16_utils.py

* change default value of sync_buffer to False

* Add world size check

* reset sync_buffer to false

* fix world_size

* Move dist functions into dist_utils.py

* fix small bugs

* Deprecation compatibility

* Change according to comments
parent 54c527ac
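
For context on what the new option looks like from a user's point of view, here is a hedged config sketch; `checkpoint_config` is the usual dict from which mmcv builds a `CheckpointHook`, and the concrete values are illustrative only:

```python
# Training-config snippet (illustrative values): the runner builds a
# CheckpointHook from this dict, so the new flag can be toggled per run.
checkpoint_config = dict(
    interval=1,         # save a checkpoint every epoch
    max_keep_ckpts=3,   # keep only the three most recent checkpoints
    sync_buffer=True)   # all-reduce model buffers (e.g. BN stats) before saving
```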
mmcv/runner/__init__.py
@@ -3,7 +3,8 @@
 from .base_runner import BaseRunner
 from .builder import RUNNERS, build_runner
 from .checkpoint import (_load_checkpoint, load_checkpoint, load_state_dict,
                          save_checkpoint, weights_to_cpu)
-from .dist_utils import get_dist_info, init_dist, master_only
+from .dist_utils import (allreduce_grads, allreduce_params, get_dist_info,
+                         init_dist, master_only)
 from .epoch_based_runner import EpochBasedRunner, Runner
 from .fp16_utils import auto_fp16, force_fp32, wrap_fp16_model
 from .hooks import (HOOKS, CheckpointHook, ClosureHook, DistSamplerSeedHook,
@@ -32,5 +33,5 @@ __all__ = [
     'build_optimizer', 'build_optimizer_constructor', 'IterLoader',
     'set_random_seed', 'auto_fp16', 'force_fp32', 'wrap_fp16_model',
     'Fp16OptimizerHook', 'SyncBuffersHook', 'EMAHook', 'build_runner',
-    'RUNNERS'
+    'RUNNERS', 'allreduce_grads', 'allreduce_params'
 ]
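
With both helpers re-exported from the package root, downstream code no longer needs to reach into internal modules. A minimal sketch of the resulting import surface (the call sites in the comments are illustrative, not part of this diff):

```python
from mmcv.runner import allreduce_grads, allreduce_params

# After this change both names resolve to mmcv/runner/dist_utils.py:
#   allreduce_grads(model.parameters())  # average gradients across ranks
#   allreduce_params(model.buffers())    # average buffers, e.g. BN statistics
```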
mmcv/runner/dist_utils.py
@@ -2,10 +2,13 @@
 import functools
 import os
 import subprocess
+from collections import OrderedDict
 
 import torch
-import torch.distributed as dist
 import torch.multiprocessing as mp
+from torch import distributed as dist
+from torch._utils import (_flatten_dense_tensors, _take_tensors,
+                          _unflatten_dense_tensors)
 
 from mmcv.utils import TORCH_VERSION
@@ -94,3 +97,71 @@ def master_only(func):
             return func(*args, **kwargs)
 
     return wrapper
+
+
+def allreduce_params(params, coalesce=True, bucket_size_mb=-1):
+    """Allreduce parameters.
+
+    Args:
+        params (list[torch.Parameters]): List of parameters or buffers of a
+            model.
+        coalesce (bool, optional): Whether to allreduce parameters as a whole.
+            Defaults to True.
+        bucket_size_mb (int, optional): Size of bucket, the unit is MB.
+            Defaults to -1.
+    """
+    _, world_size = get_dist_info()
+    if world_size == 1:
+        return
+    params = [param.data for param in params]
+    if coalesce:
+        _allreduce_coalesced(params, world_size, bucket_size_mb)
+    else:
+        for tensor in params:
+            dist.all_reduce(tensor.div_(world_size))
+
+
+def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
+    """Allreduce gradients.
+
+    Args:
+        params (list[torch.Parameters]): List of parameters of a model.
+        coalesce (bool, optional): Whether to allreduce gradients as a whole.
+            Defaults to True.
+        bucket_size_mb (int, optional): Size of bucket, the unit is MB.
+            Defaults to -1.
+    """
+    grads = [
+        param.grad.data for param in params
+        if param.requires_grad and param.grad is not None
+    ]
+    _, world_size = get_dist_info()
+    if world_size == 1:
+        return
+    if coalesce:
+        _allreduce_coalesced(grads, world_size, bucket_size_mb)
+    else:
+        for tensor in grads:
+            dist.all_reduce(tensor.div_(world_size))
+
+
+def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1):
+    # Group tensors into buckets (by size limit if given, otherwise by dtype),
+    # flatten each bucket into one contiguous tensor, all-reduce it, average,
+    # and copy the synced values back into the original tensors.
+    if bucket_size_mb > 0:
+        bucket_size_bytes = bucket_size_mb * 1024 * 1024
+        buckets = _take_tensors(tensors, bucket_size_bytes)
+    else:
+        buckets = OrderedDict()
+        for tensor in tensors:
+            tp = tensor.type()
+            if tp not in buckets:
+                buckets[tp] = []
+            buckets[tp].append(tensor)
+        buckets = buckets.values()
+
+    for bucket in buckets:
+        flat_tensors = _flatten_dense_tensors(bucket)
+        dist.all_reduce(flat_tensors)
+        flat_tensors.div_(world_size)
+        for tensor, synced in zip(
+                bucket, _unflatten_dense_tensors(flat_tensors, bucket)):
+            tensor.copy_(synced)
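
For readers who want to call the new helper directly, here is a small sketch of synchronizing BatchNorm buffers across processes. It assumes a distributed process group has already been set up (e.g. via `init_dist`); the wrapper function name is made up for illustration:

```python
import torch.nn as nn

from mmcv.runner import allreduce_params, get_dist_info


def sync_bn_buffers(model: nn.Module, bucket_size_mb: int = -1) -> None:
    """Average BN running_mean/running_var (and other buffers) across ranks."""
    _, world_size = get_dist_info()
    if world_size > 1:
        # coalesce=True flattens the buffers into a few large all_reduce calls,
        # which is usually faster than reducing each small tensor separately.
        allreduce_params(
            model.buffers(), coalesce=True, bucket_size_mb=bucket_size_mb)
```

On a single process the helper returns immediately, so the same code path is safe for non-distributed runs.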
mmcv/runner/fp16_utils.py
 import functools
-from collections import OrderedDict, abc
+import warnings
+from collections import abc
 from inspect import getfullargspec
 
 import numpy as np
 import torch
-import torch.distributed as dist
 import torch.nn as nn
-from torch._utils import (_flatten_dense_tensors, _take_tensors,
-                          _unflatten_dense_tensors)
+
+from .dist_utils import allreduce_grads as _allreduce_grads
 
 
 def cast_tensor_type(inputs, src_type, dst_type):
@@ -197,48 +197,11 @@ def force_fp32(apply_to=None, out_fp16=False):
     return force_fp32_wrapper
 
 
-def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1):
-    if bucket_size_mb > 0:
-        bucket_size_bytes = bucket_size_mb * 1024 * 1024
-        buckets = _take_tensors(tensors, bucket_size_bytes)
-    else:
-        buckets = OrderedDict()
-        for tensor in tensors:
-            tp = tensor.type()
-            if tp not in buckets:
-                buckets[tp] = []
-            buckets[tp].append(tensor)
-        buckets = buckets.values()
-
-    for bucket in buckets:
-        flat_tensors = _flatten_dense_tensors(bucket)
-        dist.all_reduce(flat_tensors)
-        flat_tensors.div_(world_size)
-        for tensor, synced in zip(
-                bucket, _unflatten_dense_tensors(flat_tensors, bucket)):
-            tensor.copy_(synced)
-
-
 def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
-    """Allreduce gradients.
-
-    Args:
-        params (list[torch.Parameters]): List of parameters of a model
-        coalesce (bool, optional): Whether allreduce parameters as a whole.
-            Defaults to True.
-        bucket_size_mb (int, optional): Size of bucket, the unit is MB.
-            Defaults to -1.
-    """
-    grads = [
-        param.grad.data for param in params
-        if param.requires_grad and param.grad is not None
-    ]
-    world_size = dist.get_world_size()
-    if coalesce:
-        _allreduce_coalesced(grads, world_size, bucket_size_mb)
-    else:
-        for tensor in grads:
-            dist.all_reduce(tensor.div_(world_size))
+    warnings.warn(
+        '"mmcv.runner.fp16_utils.allreduce_grads" is deprecated, and will be '
+        'removed in v2.8. Please switch to "mmcv.runner.allreduce_grads".')
+    _allreduce_grads(params, coalesce=coalesce, bucket_size_mb=bucket_size_mb)
 
 
 def wrap_fp16_model(model):
...
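
A quick sketch of what the compatibility shim does at runtime; the toy module is only there so the call has something to iterate over, and the snippet assumes the corrected `warnings.warn` call shown above:

```python
import warnings

import torch.nn as nn

from mmcv.runner import fp16_utils

model = nn.Linear(2, 2)  # toy module; gradients are not needed for this demo

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    # The old import path still works but emits a deprecation message and
    # forwards to mmcv.runner.allreduce_grads (a no-op on a single process).
    fp16_utils.allreduce_grads(model.parameters())

assert any('deprecated' in str(w.message) for w in caught)
```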
mmcv/runner/hooks/checkpoint.py
 # Copyright (c) Open-MMLab. All rights reserved.
 import os
 
-from ..dist_utils import master_only
+from ..dist_utils import allreduce_params, master_only
 from .hook import HOOKS, Hook
@@ -24,6 +24,8 @@ class CheckpointHook(Hook):
             In some cases we want only the latest few checkpoints and would
             like to delete old ones to save the disk space.
             Default: -1, which means unlimited.
+        sync_buffer (bool): Whether to synchronize buffers in different
+            gpus. Default: False.
     """
 
     def __init__(self,
@@ -32,6 +34,7 @@ class CheckpointHook(Hook):
                  save_optimizer=True,
                  out_dir=None,
                  max_keep_ckpts=-1,
+                 sync_buffer=False,
                  **kwargs):
         self.interval = interval
         self.by_epoch = by_epoch
@@ -39,52 +42,50 @@ class CheckpointHook(Hook):
         self.out_dir = out_dir
         self.max_keep_ckpts = max_keep_ckpts
         self.args = kwargs
+        self.sync_buffer = sync_buffer
 
-    @master_only
     def after_train_epoch(self, runner):
         if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
             return
 
         runner.logger.info(f'Saving checkpoint at {runner.epoch + 1} epochs')
+        if self.sync_buffer:
+            allreduce_params(runner.model.buffers())
+        self._save_checkpoint(runner)
+
+    @master_only
+    def _save_checkpoint(self, runner):
+        """Save the current checkpoint and delete unwanted checkpoint."""
         if not self.out_dir:
             self.out_dir = runner.work_dir
         runner.save_checkpoint(
             self.out_dir, save_optimizer=self.save_optimizer, **self.args)
         # remove other checkpoints
         if self.max_keep_ckpts > 0:
-            filename_tmpl = self.args.get('filename_tmpl', 'epoch_{}.pth')
-            current_epoch = runner.epoch + 1
-            for epoch in range(current_epoch - self.max_keep_ckpts, 0, -1):
+            if self.by_epoch:
+                name = 'epoch_{}.pth'
+                current_ckpt = runner.epoch + 1
+            else:
+                name = 'iter_{}.pth'
+                current_ckpt = runner.iter + 1
+            redundant_ckpts = range(
+                current_ckpt - self.max_keep_ckpts * self.interval, 0,
+                -self.interval)
+            filename_tmpl = self.args.get('filename_tmpl', name)
+            for _step in redundant_ckpts:
                 ckpt_path = os.path.join(self.out_dir,
-                                         filename_tmpl.format(epoch))
+                                         filename_tmpl.format(_step))
                 if os.path.exists(ckpt_path):
                     os.remove(ckpt_path)
                 else:
                     break
 
-    @master_only
     def after_train_iter(self, runner):
         if self.by_epoch or not self.every_n_iters(runner, self.interval):
             return
 
         runner.logger.info(
             f'Saving checkpoint at {runner.iter + 1} iterations')
-        if not self.out_dir:
-            self.out_dir = runner.work_dir
-        runner.save_checkpoint(
-            self.out_dir, save_optimizer=self.save_optimizer, **self.args)
-        # remove other checkpoints
-        if self.max_keep_ckpts > 0:
-            filename_tmpl = self.args.get('filename_tmpl', 'iter_{}.pth')
-            current_iter = runner.iter + 1
-            for _iter in range(
-                    current_iter - self.max_keep_ckpts * self.interval, 0,
-                    -self.interval):
-                ckpt_path = os.path.join(self.out_dir,
-                                         filename_tmpl.format(_iter))
-                if os.path.exists(ckpt_path):
-                    os.remove(ckpt_path)
-                else:
-                    break
+        if self.sync_buffer:
+            allreduce_params(runner.model.buffers())
+        self._save_checkpoint(runner)
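
The rewritten clean-up now steps backwards in multiples of `interval`, so it also works for iteration-based checkpoints. A tiny worked example of that arithmetic (the concrete numbers are made up):

```python
interval = 5        # a checkpoint every 5 iterations
max_keep_ckpts = 2  # keep only the two newest checkpoints
current_ckpt = 35   # runner.iter + 1 at save time

redundant_ckpts = range(current_ckpt - max_keep_ckpts * interval, 0, -interval)
print(list(redundant_ckpts))
# [25, 20, 15, 10, 5] -> iter_25.pth, iter_20.pth, ... are removed;
# iter_30.pth and iter_35.pth are kept.
```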
mmcv/runner/hooks/optimizer.py
@@ -5,7 +5,8 @@ from itertools import chain
 from torch.nn.utils import clip_grad
 
-from ..fp16_utils import allreduce_grads, wrap_fp16_model
+from ..dist_utils import allreduce_grads
+from ..fp16_utils import wrap_fp16_model
 from .hook import HOOKS, Hook
...
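
For orientation, a hedged sketch of where a gradient all-reduce sits in a manual training step; the loop, model, and loss handling are illustrative and not part of Fp16OptimizerHook itself, and only the `allreduce_grads` call corresponds to the import above:

```python
from mmcv.runner import allreduce_grads


def train_step(model, data_batch, optimizer):
    # Illustrative distributed step: average gradients across processes
    # before updating the weights (no-op when running on a single GPU).
    loss = model(**data_batch)  # assumes the model returns a scalar loss
    optimizer.zero_grad()
    loss.backward()
    allreduce_grads(model.parameters(), coalesce=True, bucket_size_mb=-1)
    optimizer.step()
```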
mmcv/runner/hooks/sync_buffer.py
 # Copyright (c) Open-MMLab. All rights reserved.
-import torch.distributed as dist
-
-from ..dist_utils import get_dist_info
+from ..dist_utils import allreduce_params
 from .hook import HOOKS, Hook
@@ -20,9 +18,5 @@ class SyncBuffersHook(Hook):
     def after_epoch(self, runner):
         """All-reduce model buffers at the end of each epoch."""
-        _, world_size = get_dist_info()
-        if self.distributed and world_size > 1:
-            buffers = runner.model.buffers()
-            world_size = dist.get_world_size()
-            for tensor in buffers:
-                dist.all_reduce(tensor.div_(world_size))
+        if self.distributed:
+            allreduce_params(runner.model.buffers())
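
With the world-size check now living inside `allreduce_params`, the hook body shrinks to two lines. A brief sketch of how the hook is typically attached to a runner (the runner construction itself is not shown):

```python
from mmcv.runner import SyncBuffersHook

hook = SyncBuffersHook(distributed=True)
# runner.register_hook(hook)
# After each training epoch the hook now simply calls
# allreduce_params(runner.model.buffers()), which returns immediately
# when the world size is 1.
```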