Commit 4dcec47d authored by Hang Zhang

sync once

parent c6dc6176
......@@ -8,6 +8,9 @@ How BN works?
The BN layer was introduced in the paper `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It dramatically speeds up network training (it enables larger learning rates) and makes the network less sensitive to weight initialization.
.. image:: http://hangzh.com/blog/images/bn1.png
:align: center
- Forward Pass:
For the input data :math:`X=\{x_1, ..., x_N\}`, the data are normalized to zero mean and unit variance, and then scaled and shifted:
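A minimal sketch of this normalize-then-scale-and-shift step (illustrative only: plain PyTorch rather than the fused CUDA kernels this package ships, and a 2D input of shape ``(N, C)`` is assumed):

.. code-block:: python

    import torch

    def bn_forward(x, gamma, beta, eps=1e-5):
        # per-channel statistics over the mini-batch
        mean = x.mean(dim=0)
        var = x.var(dim=0, unbiased=False)
        # normalize to zero mean and unit variance
        x_hat = (x - mean) / torch.sqrt(var + eps)
        # scale and shift with the learnable parameters gamma and beta
        return gamma * x_hat + beta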
......@@ -31,6 +34,9 @@ Why Synchronize BN?
- Standard implementations of BN in public frameworks (such as Caffe, MXNet, Torch, TF, PyTorch) are unsynchronized, which means that the data are normalized within each GPU. Therefore the `working batch-size` of the BN layer is `BatchSize/nGPU` (the batch size on each GPU).
.. image:: http://hangzh.com/blog/images/bn2.png
:align: center
- Since the `working batch-size` is typically large enough for standard vision tasks, such as classification and detection, there is no need to synchronize the BN layers during training; synchronization would only slow the training down.
- However, for the Semantic Segmentation task, state-of-the-art approaches typically adopt dilated convolution, which is very memory-consuming. The `working batch-size` can be too small for BN layers (2 or 4 in each GPU) when using larger/deeper pre-trained networks, such as :class:`encoding.dilated.ResNet` or :class:`encoding.dilated.DenseNet`.
......@@ -47,8 +53,10 @@ Suppose we have :math:`K` number of GPUs, :math:`sum(x)_k` and :math:`sum(x^2)_k
* :math:`\frac{d_\ell}{d_{x_i}}=\frac{d_\ell}{d_{y_i}}\frac{\gamma}{\sigma}` can be calculated locally on each GPU.
* Calculate the gradients with respect to :math:`sum(x)` and :math:`sum(x^2)` individually on each GPU, i.e. :math:`\frac{d_\ell}{d_{sum(x)_k}}` and :math:`\frac{d_\ell}{d_{sum(x^2)_k}}`.
* Then sync the gradients (automatically handled by :class:`encoding.parallel.allreduce`) and continue the backward pass.
* Then sync the gradients (automatically handled by :class:`encoding.parallel.AllReduce`) and continue the backward pass.
.. image:: http://hangzh.com/blog/images/bn3.png
:align: center
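The same computation in a minimal sketch, assuming the per-GPU partial sums :math:`sum(x)_k` and :math:`sum(x^2)_k` have already been gathered into Python lists (illustrative only; the package itself uses :class:`encoding.parallel.allreduce` and custom CUDA kernels for this):

.. code-block:: python

    import torch

    def global_statistics(xsums, xsquares, N, eps=1e-5):
        # xsums[k], xsquares[k]: per-channel sum(x) and sum(x^2) from GPU k
        # N: total number of elements per channel in the whole mini-batch
        xsum = torch.stack(xsums).sum(dim=0)
        xsquare = torch.stack(xsquares).sum(dim=0)
        mean = xsum / N
        var = xsquare / N - mean * mean   # E[x^2] - E[x]^2
        std = (var + eps).sqrt()
        return mean, std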
Citation
--------
......
......@@ -55,9 +55,7 @@ class _sum_square(Function):
def sum_square(input):
r"""
Calculate sum of elements and sum of squares for Batch Normalization.
"""
r"""Calculate sum of elements and sum of squares for Batch Normalization"""
return _sum_square.apply(input)
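# Usage sketch (illustrative; assumes a 4D activation tensor on one GPU):
#
#   x = torch.randn(2, 3, 4, 4).cuda()
#   xsum, xsqsum = sum_square(x)   # per-channel sum(x) and sum(x^2)
#
# These per-device partial sums are what the SyncBN layers all-reduce across
# GPUs to form the global batch statistics.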
......
......@@ -13,136 +13,66 @@ import threading
import torch
from torch.nn import Module, Sequential, Conv1d, Conv2d, ConvTranspose2d, \
ReLU, Sigmoid, MaxPool2d, AvgPool2d, AdaptiveAvgPool2d, Dropout2d, Linear
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.parameter import Parameter
from ..functions import batchnormtrain, batchnormeval, sum_square
from ..parallel import allreduce
# import standard layers for convenient use
__all__ = ['BatchNorm1d', 'BatchNorm2d', 'Module', 'Sequential', 'Conv1d',
'Conv2d', 'ConvTranspose2d', 'ReLU', 'Sigmoid', 'MaxPool2d',
'AvgPool2d', 'AdaptiveAvgPool2d', 'Dropout2d', 'Linear']
#__all__ = ['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']
class BatchNorm1d(Module):
r"""Cross-GPU Synchronized Batch normalization (SyncBN)
The standard BN [1]_ implementation only normalizes the data within each device.
SyncBN normalizes the input over the whole mini-batch.
We follow the sync-once implementation described in the paper [2]_.
.. math::
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
The mean and standard-deviation are calculated per-dimension over
the mini-batches and gamma and beta are learnable parameter vectors
of size C (where C is the input size).
During training, this layer keeps a running estimate of its computed mean
and variance. The running sum is kept with a default momentum of 0.1.
During evaluation, this running mean/variance is used for normalization.
Because the BatchNorm is done over the `C` dimension, computing statistics
on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm.
Args:
num_features: num_features from an expected input of size
`batch_size x num_features [x width]`
eps: a value added to the denominator for numerical stability.
Default: 1e-5
momentum: the value used for the running_mean and running_var
computation. Default: 0.1
affine: a boolean value that when set to ``True``, gives the layer learnable
affine parameters. Default: ``True``
Shape:
- Input: :math:`(N, C)` or :math:`(N, C, L)`
- Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
Examples:
>>> # Used exactly like the standard BatchNorm1d
>>> m = nn.BatchNorm1d(100)
>>> output = m(input)
"""
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
super(BatchNorm1d, self).__init__()
self.num_features = num_features
self.affine = affine
self.eps = eps
self.momentum = momentum
if self.affine:
self.weight = Parameter(torch.Tensor(num_features))
self.bias = Parameter(torch.Tensor(num_features))
else:
self.register_parameter('weight', None)
self.register_parameter('bias', None)
self.register_buffer('running_mean', torch.zeros(num_features))
self.register_buffer('running_var', torch.ones(num_features))
self.reset_parameters()
class _SyncBatchNorm(_BatchNorm):
def __init__(self, num_features, eps=1e-5, momentum=0.1, **kwargs):
super(_SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, **kwargs)
# syncBN
self.writelock = threading.Lock()
nGPUs = torch.cuda.device_count()
self.xsum = SharedTensor(nGPUs)
self.xsquare = SharedTensor(nGPUs)
def reset_parameters(self):
self.running_mean.zero_()
self.running_var.fill_(1)
if self.affine:
self.weight.data.uniform_()
self.bias.data.zero_()
def __repr__(self):
return ('{name}({num_features}, eps={eps}, momentum={momentum},'
' affine={affine})'
.format(name=self.__class__.__name__, **self.__dict__))
def _check_input_dim(self, input):
if input.dim() != 3:
raise ValueError('expected 3D input (got {}D input)'
.format(input.dim()))
self.sharedT = SharedTensor(nGPUs)
def forward(self, input):
self._check_input_dim(input)
if self.training:
# push the value
isum, isquare = sum_square(input.unsqueeze(3))
idxs = self.xsum.push(isum)
idxq = self.xsquare.push(isquare)
xsum = self.xsum[idxs]
xsquare = self.xsquare[idxq]
# calculate N
N = len(self.xsum)*input.size(0)*input.size(2)
mean = xsum / N
sumvar = xsquare - xsum * xsum / N
unbias_var = sumvar / (N - 1)
std = (sumvar / N + self.eps).sqrt()
# update running_mean and var
self.running_mean = (1-self.momentum) * self.running_mean \
+ self.momentum * mean.data
self.running_var = (1-self.momentum) * self.running_var + \
self.momentum * unbias_var.data
# forward
return batchnormtrain(input, self.weight,
self.bias, mean, std)
else:
std = (self.running_var + self.eps).sqrt()
return batchnormeval(input, self.weight, self.bias,
self.running_mean, std)
input_shape = input.size()
input = input.view(input_shape[0], self.num_features, -1)
if not self.training:
std = (self.running_var.clamp(self.eps)).sqrt()
output = batchnormeval(input, self.weight, self.bias, self.running_mean, std)
return output.view(input_shape)
# get global sum(x) and sum(x^2)
xsum, xsquare = self.sharedT(sum_square(input.unsqueeze(3)))
# calculate mean, var
N = len(self.sharedT) * input.size(0) * input.size(2)
mean = xsum / N
sumvar = xsquare - xsum * xsum / N
unbias_var = sumvar / (N - 1)
bias_var = sumvar / N
std = bias_var.clamp(self.eps).sqrt()
# update running_mean and var
self.running_mean = (1-self.momentum) * self.running_mean + self.momentum * mean.data
self.running_var = (1-self.momentum) * self.running_var + self.momentum * unbias_var.data
# forward
return batchnormtrain(input, self.weight, self.bias, mean, std).view(input_shape)
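# Note on the statistics above (a derivation, not part of this package's API):
# sum((x - mean)^2) = sum(x^2) - sum(x)^2 / N, so `sumvar / N` is the biased
# variance and `sumvar / (N - 1)` the unbiased estimate.  A quick
# single-process sanity check:
#
#   x = torch.randn(8, 3, 16)
#   xsum, xsquare = x.sum(dim=(0, 2)), (x * x).sum(dim=(0, 2))
#   N = x.size(0) * x.size(2)
#   bias_var = (xsquare - xsum * xsum / N) / N
#   assert torch.allclose(bias_var, x.var(dim=(0, 2), unbiased=False), atol=1e-5)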
class BatchNorm1d(_SyncBatchNorm):
r"""Please see the docs in :class:`encoding.nn.BatchNorm2d`"""
def _check_input_dim(self, input):
if input.dim() != 2 and input.dim() != 3:
raise ValueError('expected 2D or 3D input (got {}D input)'
.format(input.dim()))
class BatchNorm2d(Module):
class BatchNorm2d(_SyncBatchNorm):
r"""Cross-GPU Synchronized Batch normalization (SyncBN)
The standard BN [1]_ implementation only normalizes the data within each device.
SyncBN normalizes the input over the whole mini-batch.
We follow the sync-once implementation described in the paper [2]_.
Please see the design idea in the `notes <./notes/syncbn.html>`_.
.. math::
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
The mean and standard-deviation are calculated per-dimension over
The mean and standard-deviation are calculated per-channel over
the mini-batches and gamma and beta are learnable parameter vectors
of size C (where C is the input size).
......@@ -177,78 +107,20 @@ class BatchNorm2d(Module):
>>> m = nn.BatchNorm2d(100)
>>> output = m(input)
"""
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
super(BatchNorm2d, self).__init__()
self.num_features = num_features
self.affine = affine
self.eps = eps
self.momentum = momentum
if self.affine:
self.weight = Parameter(torch.Tensor(num_features))
self.bias = Parameter(torch.Tensor(num_features))
else:
self.register_parameter('weight', None)
self.register_parameter('bias', None)
self.register_buffer('running_mean', torch.zeros(num_features))
self.register_buffer('running_var', torch.ones(num_features))
self.reset_parameters()
self.writelock = threading.Lock()
nGPUs = torch.cuda.device_count()
self.xsum, self.xsquare = SharedTensor(nGPUs), SharedTensor(nGPUs)
def reset_parameters(self):
self.running_mean.zero_()
self.running_var.fill_(1)
if self.affine:
self.weight.data.uniform_()
self.bias.data.zero_()
def __repr__(self):
return ('{name}({num_features}, eps={eps}, momentum={momentum},'
' affine={affine})'
.format(name=self.__class__.__name__, **self.__dict__))
def _check_input_dim(self, input):
if input.dim() != 4:
raise ValueError('expected 4D input (got {}D input)'
.format(input.dim()))
def forward(self, input):
self._check_input_dim(input)
if self.training:
# push the value
isum, isquare = sum_square(input)
idxs = self.xsum.push(isum)
idxq = self.xsquare.push(isquare)
xsum = self.xsum[idxs]
xsquare = self.xsquare[idxq]
# calculate N
N = len(self.xsum)*input.size(0)*input.size(2)*input.size(3)
mean = xsum / N
sumvar = xsquare - xsum * xsum / N
unbias_var = sumvar / (N - 1)
std = (sumvar / N + self.eps).sqrt()
# update running_mean and var
self.running_mean = (1-self.momentum) * self.running_mean \
+ self.momentum * mean.data
self.running_var = (1-self.momentum) * self.running_var + \
self.momentum * unbias_var.data
# forward
B, C, H, W = input.size()
output = batchnormtrain(
input.view(B, C, -1).contiguous(), self.weight,
self.bias, mean,
std)
return output.view(B, C, H, W)
else:
std = (self.running_var + self.eps).sqrt()
B, C, H, W = input.size()
return batchnormeval(input.view(B, C, -1).contiguous(), self.weight, self.bias,
self.running_mean, std).view(B, C, H, W)
class BatchNorm3d(_SyncBatchNorm):
r"""Please see the docs in :class:`encoding.nn.BatchNorm2d`"""
def _check_input_dim(self, input):
if input.dim() != 5:
raise ValueError('expected 5D input (got {}D input)'
.format(input.dim()))
class SharedTensor(object):
"""Shared Tensor
"""Shared Tensor for cross GPU communication
"""
def __init__(self, nGPUs):
self.mutex = threading.Lock()
......@@ -261,44 +133,37 @@ class SharedTensor(object):
self.push_tasks = self.nGPUs
self.reduce_tasks = self.nGPUs
def push(self, t):
"""push a Tensor
"""
def __call__(self, *inputs):
# push from device
with self.mutex:
if self.push_tasks == 0:
self._clear()
self.list.append(t)
idx = len(self.list) - 1
self.list.extend(list(*inputs))
idx = self.nGPUs - self.push_tasks
self.push_tasks -= 1
with self.all_tasks_done:
if self.push_tasks == 0:
self.all_tasks_done.notify_all()
while self.push_tasks:
self.all_tasks_done.wait()
return idx
def _reduce(self):
# pull from device
with self.mutex:
if self.reduce_tasks == self.nGPUs:
assert(len(self.list) == self.nGPUs)
self.outlist = allreduce(*self.list)
assert(len(self.list) == 2 * self.nGPUs)
self.list = allreduce(2, *self.list)
self.reduce_tasks -= 1
else:
self.reduce_tasks -= 1
with self.all_tasks_done:
if self.reduce_tasks == 0:
self.all_tasks_done.notify_all()
while self.reduce_tasks:
self.all_tasks_done.wait()
def __getitem__(self, idx):
self._reduce()
return self.outlist[idx]
# all reduce done
return self.list[2*idx], self.list[2*idx+1]
def __len__(self):
return len(self.list)
return self.nGPUs
def __repr__(self):
return ('SharedTensor')
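# Usage sketch (illustrative): each SyncBN layer owns one SharedTensor with
# nGPUs slots.  Every replica's forward pass, running in its own thread under
# the data-parallel wrapper, calls it once per iteration, e.g.
#
#   xsum, xsquare = self.sharedT(sum_square(input.unsqueeze(3)))
#
# Roughly, the call blocks until all nGPUs replicas have pushed their partial
# sums, all-reduces the (sum, sum-of-squares) pairs across devices, and hands
# the global sum(x) and sum(x^2) back to every calling thread.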
......@@ -11,28 +11,42 @@
"""Encoding Data Parallel"""
import threading
import torch
from torch.autograd import Function
import torch.cuda.comm as comm
from torch.autograd import Variable
from torch.nn.modules import Module
from torch.nn.parallel.data_parallel import DataParallel
from torch.nn.parallel.scatter_gather import scatter_kwargs
from torch.nn.parallel.replicate import replicate
from torch.nn.parallel.parallel_apply import parallel_apply
from torch.nn.parallel.parallel_apply import get_a_var
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
__all__ = ['allreduce', 'ModelDataParallel', 'CriterionDataParallel']
def allreduce(*inputs):
def allreduce(num_inputs, *inputs):
"""Cross GPU all reduce autograd operation for calculate mean and
variance in SyncBN.
"""
target_gpus = [inputs[i].get_device() for i in range(len(inputs))]
result = ReduceAddCoalesced.apply(target_gpus[0], 1, *inputs)
target_gpus = [inputs[i].get_device() for i in range(0, len(inputs), num_inputs)]
result = ReduceAddCoalesced.apply(target_gpus[0], num_inputs, *inputs)
outputs = Broadcast.apply(target_gpus, *result)
assert len(outputs) == len(inputs)
return outputs
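# Usage sketch (mirrors tests/unit_test.py): reduce one tensor per GPU and
# broadcast the sum back to each device.
#
#   X = [torch.randn(2, 4, 4).cuda(i) for i in range(ngpu)]
#   Y = allreduce(1, *X)   # len(Y) == len(X); Y[i] lives on the same GPU as X[i]
#
# With num_inputs=2, as the SyncBN layers use for their (sum, sum-of-squares)
# pairs, consecutive groups of two tensors per device are reduced coalesced.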
class ModelDataParallel(Module):
class Reduce(Function):
@staticmethod
def forward(ctx, *inputs):
ctx.target_gpus = [inputs[i].get_device() for i in range(len(inputs))]
return comm.reduce_add(inputs)
@staticmethod
def backward(ctx, gradOutput):
return Broadcast.apply(ctx.target_gpus, gradOutput)
class DataParallelModel(DataParallel):
"""Implements data parallelism at the module level.
This container parallelizes the application of the given module by
......@@ -41,7 +55,7 @@ class ModelDataParallel(Module):
In the forward pass, the module is replicated on each device,
and each replica handles a portion of the input. During the backwards pass, gradients from each replica are summed into the original module.
Note that the outputs are not gathered; please use the compatible
:class:`encoding.parallel.CriterionDataParallel`.
:class:`encoding.parallel.DataParallelCriterion`.
The batch size should be larger than the number of GPUs used. It should
also be an integer multiple of the number of GPUs so that each chunk is
......@@ -58,49 +72,20 @@ class ModelDataParallel(Module):
Example::
>>> net = encoding.nn.ModelDataParallel(model, device_ids=[0, 1, 2])
>>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
>>> y = net(x)
"""
def __init__(self, module, device_ids=None, output_device=None, dim=0):
super(ModelDataParallel, self).__init__()
if device_ids is None:
device_ids = list(range(torch.cuda.device_count()))
if output_device is None:
output_device = device_ids[0]
self.dim = dim
self.module = module
self.device_ids = device_ids
self.output_device = output_device
self.master_mean, self.master_var = {}, {}
if len(self.device_ids) == 1:
self.module.cuda(device_ids[0])
def forward(self, *inputs, **kwargs):
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
if len(self.device_ids) == 1:
return self.module(*inputs[0], **kwargs[0])
replicas = self.replicate(self.module, \
self.device_ids[:len(inputs)])
outputs = self.parallel_apply(replicas, inputs, kwargs)
def gather(self, outputs, output_device):
return outputs
def replicate(self, module, device_ids):
return replicate(module, device_ids)
def scatter(self, inputs, kwargs, device_ids):
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
def parallel_apply(self, replicas, inputs, kwargs):
return parallel_apply(replicas, inputs, kwargs)
class CriterionDataParallel(Module):
class DataParallelCriterion(DataParallel):
"""
Calculate the loss on multiple GPUs, which balances memory usage for
Semantic Segmentation.
The targets are split across the specified devices by chunking in
the batch dimension. Please use together with :class:`encoding.parallel.ModelDataParallel`.
the batch dimension. Please use together with :class:`encoding.parallel.DataParallelModel`.
Reference:
Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
......@@ -109,79 +94,67 @@ class CriterionDataParallel(Module):
Example::
>>> net = encoding.nn.ModelDataParallel(model, device_ids=[0, 1, 2])
>>> criterion = encoding.nn.CriterionDataParallel(criterion, device_ids=[0, 1, 2])
>>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
>>> criterion = encoding.nn.DataParallelCriterion(criterion, device_ids=[0, 1, 2])
>>> y = net(x)
>>> loss = criterion(y, target)
"""
def __init__(self, module, device_ids=None, output_device=None, dim=0):
super(CriterionDataParallel, self).__init__()
if device_ids is None:
device_ids = list(range(torch.cuda.device_count()))
if output_device is None:
output_device = device_ids[0]
self.dim = dim
self.module = module
self.device_ids = device_ids
self.output_device = output_device
if len(self.device_ids) == 1:
self.module.cuda(device_ids[0])
def forward(self, inputs, *targets, **kwargs):
# inputs should already be scattered
# scattering the targets instead
targets, kwargs = self.scatter(targets, kwargs, self.device_ids)
if not self.device_ids:
return self.module(inputs, *targets, **kwargs)
targets, kwargs = self.scatter(targets, kwargs, self.device_ids)
if len(self.device_ids) == 1:
return self.module(inputs, *targets[0], **kwargs[0])
replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
outputs = self.parallel_apply(replicas, inputs, targets, kwargs)
return ReduceAddCoalesced.apply(self.output_device, 1, *outputs) / len(outputs)
def replicate(self, module, device_ids):
return replicate(module, device_ids)
def scatter(self, inputs, kwargs, device_ids):
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
replicas = replicate(self.module, self.device_ids[:len(inputs)])
outputs = _criterion_parallel_apply(replicas, inputs, targets, kwargs)
return Reduce.apply(*outputs) / len(outputs)
def parallel_apply(self, replicas, inputs, targets, kwargs):
return _criterion_parallel_apply(replicas, inputs, targets, kwargs)
def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None):
def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None):
assert len(modules) == len(inputs)
assert len(targets) == len(inputs)
if kwargs_tup:
assert len(modules) == len(kwargs_tup)
else:
kwargs_tup = ({},) * len(modules)
# Fast track
if len(modules) == 1:
return (modules[0](*inputs[0], *targets[0], **kwargs_tup[0]), )
if devices is not None:
assert len(modules) == len(devices)
else:
devices = [None] * len(modules)
lock = threading.Lock()
results = {}
grad_enabled = torch.is_grad_enabled()
def _worker(i, module, input, target, kwargs, results, lock):
def _worker(i, module, input, target, kwargs, device=None):
torch.set_grad_enabled(grad_enabled)
if device is None:
device = get_a_var(input).get_device()
try:
var_input = _get_a_var(input)
with torch.cuda.device_of(var_input):
output = module(input, *target, **kwargs)
output = module(*(input + target), **kwargs)
with lock:
results[i] = output
except Exception as e:
with lock:
results[i] = e
threads = [threading.Thread(target=_worker,
args=(i, module, input, target,
kwargs, results, lock),)
for i, (module, input, target, kwargs) in
enumerate(zip(modules, inputs, targets, kwargs_tup))]
if len(modules) > 1:
threads = [threading.Thread(target=_worker,
args=(i, module, input, target,
kwargs, device),)
for i, (module, input, target, kwargs, device) in
enumerate(zip(modules, inputs, targets, kwargs_tup, devices))]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
else:
_worker(0, modules[0], inputs[0], targets[0], kwargs_tup[0], devices[0])
for thread in threads:
thread.start()
for thread in threads:
thread.join()
outputs = []
for i in range(len(inputs)):
output = results[i]
......@@ -190,19 +163,3 @@ def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None):
outputs.append(output)
return outputs
def _get_a_var(obj):
if isinstance(obj, Variable):
return obj
if isinstance(obj, list) or isinstance(obj, tuple):
results = map(_get_a_var, obj)
for result in results:
if isinstance(result, Variable):
return result
if isinstance(obj, dict):
results = map(_get_a_var, obj.items())
for result in results:
if isinstance(result, Variable):
return result
return None
......@@ -23,7 +23,7 @@ class install(setuptools.command.install.install):
def run(self):
self.create_version_file()
setuptools.command.install.install.run(self)
subprocess.check_call("python tests/unit_test.py".split())
#subprocess.check_call("python tests/unit_test.py".split())
@staticmethod
def create_version_file():
global version, cwd
......
......@@ -69,7 +69,7 @@ def test_all_reduce():
X = [torch.DoubleTensor(2,4,4).uniform_(-0.5,0.5).cuda(i) for i in range(ngpu)]
for x in X:
x.requires_grad = True
Y = encoding.parallel.allreduce(*X)
Y = encoding.parallel.allreduce(1, *X)
assert (len(X) == len(Y))
......