Commit 82356fd9 authored by Kai Chen's avatar Kai Chen
Browse files

support chunk when reducing grads

parent 3d2b79bd
from .dist_utils import (init_dist, reduce_grads, DistOptimizerHook,
DistSamplerSeedHook)
from .dist_utils import init_dist, allreduce_grads, DistOptimizerHook
from .misc import tensor2imgs, unmap, multi_apply
__all__ = [
'init_dist', 'reduce_grads', 'DistOptimizerHook', 'DistSamplerSeedHook',
'tensor2imgs', 'unmap', 'multi_apply'
'init_dist', 'allreduce_grads', 'DistOptimizerHook', 'tensor2imgs',
'unmap', 'multi_apply'
]
......@@ -4,9 +4,9 @@ from collections import OrderedDict
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.nn.utils import clip_grad
from mmcv.runner import Hook, OptimizerHook
from torch._utils import (_flatten_dense_tensors, _unflatten_dense_tensors,
_take_tensors)
from mmcv.runner import OptimizerHook
def init_dist(launcher, backend='nccl', **kwargs):
......@@ -38,59 +38,52 @@ def _init_dist_slurm(backend, **kwargs):
raise NotImplementedError
# modified from
# https://github.com/NVIDIA/apex/blob/master/apex/parallel/distributed.py#L9
def all_reduce_coalesced(tensors):
buckets = OrderedDict()
for tensor in tensors:
tp = tensor.type()
if tp not in buckets:
buckets[tp] = []
buckets[tp].append(tensor)
world_size = dist.get_world_size()
for tp in buckets:
bucket = buckets[tp]
coalesced = _flatten_dense_tensors(bucket)
dist.all_reduce(coalesced)
coalesced.div_(world_size)
for buf, synced in zip(bucket,
_unflatten_dense_tensors(coalesced, bucket)):
buf.copy_(synced)
def reduce_grads(model, coalesce=True):
def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1):
if bucket_size_mb > 0:
bucket_size_bytes = bucket_size_mb * 1024 * 1024
buckets = _take_tensors(tensors, bucket_size_bytes)
else:
buckets = OrderedDict()
for tensor in tensors:
tp = tensor.type()
if tp not in buckets:
buckets[tp] = []
buckets[tp].append(tensor)
buckets = buckets.values()
for bucket in buckets:
flat_tensors = _flatten_dense_tensors(bucket)
dist.all_reduce(flat_tensors)
flat_tensors.div_(world_size)
for tensor, synced in zip(
bucket, _unflatten_dense_tensors(flat_tensors, bucket)):
tensor.copy_(synced)
def allreduce_grads(model, coalesce=True, bucket_size_mb=-1):
grads = [
param.grad.data for param in model.parameters()
if param.requires_grad and param.grad is not None
]
world_size = dist.get_world_size()
if coalesce:
all_reduce_coalesced(grads)
_allreduce_coalesced(grads, world_size, bucket_size_mb)
else:
world_size = dist.get_world_size()
for tensor in grads:
dist.all_reduce(tensor.div_(world_size))
class DistOptimizerHook(OptimizerHook):
def __init__(self, grad_clip=None, coalesce=True):
def __init__(self, grad_clip=None, coalesce=True, bucket_size_mb=-1):
self.grad_clip = grad_clip
self.coalesce = coalesce
self.bucket_size_mb = bucket_size_mb
def after_train_iter(self, runner):
runner.optimizer.zero_grad()
runner.outputs['loss'].backward()
reduce_grads(runner.model, self.coalesce)
allreduce_grads(runner.model, self.coalesce, self.bucket_size_mb)
if self.grad_clip is not None:
clip_grad.clip_grad_norm_(
filter(lambda p: p.requires_grad, runner.model.parameters()),
**self.grad_clip)
self.clip_grads(runner.model.parameters())
runner.optimizer.step()
class DistSamplerSeedHook(Hook):
def before_epoch(self, runner):
runner.data_loader.sampler.set_epoch(runner.epoch)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment