Commit 4dcec47d authored by Hang Zhang

sync once

parent c6dc6176
@@ -8,6 +8,9 @@ How BN works?

BN layer was introduced in the paper `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_, which dramatically speeds up the training process of the network (enables larger learning rates) and makes the network less sensitive to the weight initialization.

+.. image:: http://hangzh.com/blog/images/bn1.png
+    :align: center

- Forward Pass:
    For the input data :math:`X={x_1, ...x_N}`, the data are normalized to be zero-mean and unit-variance, then scaled and shifted:
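As a concrete reference for the forward pass described above, here is a minimal sketch in plain PyTorch (illustrative only; `gamma`/`beta` stand for the learnable scale and shift, and this is not the package's CUDA kernel):

    import torch

    def bn_forward_train(x, gamma, beta, eps=1e-5):
        # x: (N, C, H, W); statistics are taken per channel over the N*H*W elements
        mean = x.mean(dim=(0, 2, 3), keepdim=True)
        var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
        x_hat = (x - mean) / torch.sqrt(var + eps)                       # zero-mean, unit-variance
        return gamma.view(1, -1, 1, 1) * x_hat + beta.view(1, -1, 1, 1)  # scale and shift

    x = torch.randn(8, 3, 4, 4)
    y = bn_forward_train(x, torch.ones(3), torch.zeros(3))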
@@ -31,6 +34,9 @@ Why Synchronize BN?

- Standard implementations of BN in public frameworks (such as Caffe, MXNet, Torch, TF, PyTorch) are unsynchronized, which means that the data are normalized within each GPU. Therefore the `working batch-size` of the BN layer is `BatchSize/nGPU` (the batch-size in each GPU).

+.. image:: http://hangzh.com/blog/images/bn2.png
+    :align: center

- Since the `working batch-size` is typically large enough for standard vision tasks, such as classification and detection, there is no need to synchronize the BN layer during training. The synchronization would only slow down the training.
- However, for the Semantic Segmentation task, the state-of-the-art approaches typically adopt dilated convolution, which is very memory consuming. The `working batch-size` can be too small for BN layers (2 or 4 in each GPU) when using larger/deeper pre-trained networks, such as :class:`encoding.dilated.ResNet` or :class:`encoding.dilated.DenseNet`.
...@@ -47,8 +53,10 @@ Suppose we have :math:`K` number of GPUs, :math:`sum(x)_k` and :math:`sum(x^2)_k ...@@ -47,8 +53,10 @@ Suppose we have :math:`K` number of GPUs, :math:`sum(x)_k` and :math:`sum(x^2)_k
* :math:`\frac{d_\ell}{d_{x_i}}=\frac{d_\ell}{d_{y_i}}\frac{\gamma}{\sigma}` can be calculated locally in each GPU. * :math:`\frac{d_\ell}{d_{x_i}}=\frac{d_\ell}{d_{y_i}}\frac{\gamma}{\sigma}` can be calculated locally in each GPU.
* Calculate the gradient of :math:`sum(x)` and :math:`sum(x^2)` individually in each GPU :math:`\frac{d_\ell}{d_{sum(x)_k}}` and :math:`\frac{d_\ell}{d_{sum(x^2)_k}}`. * Calculate the gradient of :math:`sum(x)` and :math:`sum(x^2)` individually in each GPU :math:`\frac{d_\ell}{d_{sum(x)_k}}` and :math:`\frac{d_\ell}{d_{sum(x^2)_k}}`.
* Then Sync the gradient (automatically handled by :class:`encoding.parallel.allreduce`) and continue the backward. * Then Sync the gradient (automatically handled by :class:`encoding.parallel.AllReduce`) and continue the backward.
.. image:: http://hangzh.com/blog/images/bn3.png
:align: center
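The statistics side of this scheme is easy to check on the CPU: each device only contributes :math:`sum(x)_k` and :math:`sum(x^2)_k`, a single all-reduce yields the global sums, and the global mean/std follow. A small illustrative sketch (plain PyTorch, no GPUs involved):

    import torch

    # K "devices", each holding a local chunk of the mini-batch
    chunks = [torch.randn(2, 3, 4, 4, dtype=torch.float64) for _ in range(4)]

    # per-device partial sums over everything except the channel dimension
    xsums    = [c.sum(dim=(0, 2, 3)) for c in chunks]        # sum(x)_k
    xsquares = [(c * c).sum(dim=(0, 2, 3)) for c in chunks]  # sum(x^2)_k

    # the "all-reduce": one global sum per statistic, a single synchronization point
    xsum, xsquare = sum(xsums), sum(xsquares)
    N = sum(c.numel() // c.size(1) for c in chunks)

    mean = xsum / N
    std = (xsquare / N - mean * mean + 1e-5).sqrt()          # Var[x] = E[x^2] - E[x]^2

    # matches the statistics of the concatenated mini-batch
    full = torch.cat(chunks, dim=0)
    assert torch.allclose(mean, full.mean(dim=(0, 2, 3)))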
Citation
--------
......
@@ -55,9 +55,7 @@ class _sum_square(Function):

def sum_square(input):
-    r"""
-    Calculate sum of elements and sum of squares for Batch Normalization.
-    """
+    r"""Calculate sum of elements and sum of squares for Batch Normalization"""
    return _sum_square.apply(input)
......
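For reference, the contract the SyncBN forward pass in this commit relies on is a per-channel reduction; a plain-PyTorch stand-in could look like the hypothetical helper below, while `sum_square` itself dispatches to the package's CUDA kernel:

    import torch

    def sum_square_reference(x):
        # x: (N, C, H, W) -> per-channel sum(x) and sum(x^2)
        return x.sum(dim=(0, 2, 3)), (x * x).sum(dim=(0, 2, 3))

    xsum, xsqu = sum_square_reference(torch.randn(2, 3, 4, 4))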
@@ -13,136 +13,66 @@ import threading

import torch
from torch.nn import Module, Sequential, Conv1d, Conv2d, ConvTranspose2d, \
    ReLU, Sigmoid, MaxPool2d, AvgPool2d, AdaptiveAvgPool2d, Dropout2d, Linear
+from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.parameter import Parameter

from ..functions import batchnormtrain, batchnormeval, sum_square
from ..parallel import allreduce

-# import standard layers for convinent use
-__all__ = ['BatchNorm1d', 'BatchNorm2d', 'Module', 'Sequential', 'Conv1d',
-           'Conv2d', 'ConvTranspose2d', 'ReLU', 'Sigmoid', 'MaxPool2d',
-           'AvgPool2d', 'AdaptiveAvgPool2d', 'Dropout2d', 'Linear']
+#__all__ = ['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']
-class BatchNorm1d(Module):
-    r"""Cross-GPU Synchronized Batch normalization (SyncBN)
-
-    Standard BN [1]_ implementation only normalize the data within each device.
-    SyncBN normalizes the input within the whole mini-batch.
-    We follow the sync-onece implmentation described in the paper [2]_ .
-
-    .. math::
-        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
-
-    The mean and standard-deviation are calculated per-dimension over
-    the mini-batches and gamma and beta are learnable parameter vectors
-    of size C (where C is the input size).
-    During training, this layer keeps a running estimate of its computed mean
-    and variance. The running sum is kept with a default momentum of 0.1.
-    During evaluation, this running mean/variance is used for normalization.
-    Because the BatchNorm is done over the `C` dimension, computing statistics
-    on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
-
-    Args:
-        num_features: num_features from an expected input of size
-            `batch_size x num_features [x width]`
-        eps: a value added to the denominator for numerical stability.
-            Default: 1e-5
-        momentum: the value used for the running_mean and running_var
-            computation. Default: 0.1
-        affine: a boolean value that when set to ``True``, gives the layer learnable
-            affine parameters. Default: ``True``
-
-    Shape:
-        - Input: :math:`(N, C)` or :math:`(N, C, L)`
-        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
-
-    Examples:
-        >>> # Use exactly the same as standard BatchNrom1d
-        >>> m = nn.BatchNorm1d(100)
-        >>> output = m(input)
-    """
-    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
-        super(BatchNorm1d, self).__init__()
-        self.num_features = num_features
-        self.affine = affine
-        self.eps = eps
-        self.momentum = momentum
-        if self.affine:
-            self.weight = Parameter(torch.Tensor(num_features))
-            self.bias = Parameter(torch.Tensor(num_features))
-        else:
-            self.register_parameter('weight', None)
-            self.register_parameter('bias', None)
-        self.register_buffer('running_mean', torch.zeros(num_features))
-        self.register_buffer('running_var', torch.ones(num_features))
-        self.reset_parameters()
-        self.writelock = threading.Lock()
-        nGPUs = torch.cuda.device_count()
-        self.xsum = SharedTensor(nGPUs)
-        self.xsquare = SharedTensor(nGPUs)
-
-    def reset_parameters(self):
-        self.running_mean.zero_()
-        self.running_var.fill_(1)
-        if self.affine:
-            self.weight.data.uniform_()
-            self.bias.data.zero_()
-
-    def __repr__(self):
-        return ('{name}({num_features}, eps={eps}, momentum={momentum},'
-                ' affine={affine})'
-                .format(name=self.__class__.__name__, **self.__dict__))
-
-    def _check_input_dim(self, input):
-        if input.dim() != 3:
-            raise ValueError('expected 4D input (got {}D input)'
-                             .format(input.dim()))
-
-    def forward(self, input):
-        self._check_input_dim(input)
-        if self.training:
-            # push the value
-            isum, isquare = sum_square(input.unsqueeze(3))
-            idxs = self.xsum.push(isum)
-            idxq = self.xsquare.push(isquare)
-            xsum = self.xsum[idxs]
-            xsquare = self.xsquare[idxq]
-            # calculate N
-            N = len(self.xsum)*input.size(0)*input.size(2)
-            mean = xsum / N
-            sumvar = xsquare - xsum * xsum / N
-            unbias_var = sumvar / (N - 1)
-            std = (sumvar / N + self.eps).sqrt()
-            # update running_mean and var
-            self.running_mean = (1-self.momentum) * self.running_mean \
-                + self.momentum * mean.data
-            self.running_var = (1-self.momentum) * self.running_var + \
-                self.momentum * unbias_var.data
-            # forward
-            return batchnormtrain(input, self.weight,
-                                  self.bias, mean, std)
-        else:
-            std = (self.running_var + self.eps).sqrt()
-            return batchnormeval(input, self.weight, self.bias,
-                                 self.running_mean, std)
+class _SyncBatchNorm(_BatchNorm):
+    def __init__(self, num_features, eps=1e-5, momentum=0.1, **kwargs):
+        super(_SyncBatchNorm, self).__init__(num_features, eps=1e-5, momentum=0.1, **kwargs)
+        # syncBN
+        self.writelock = threading.Lock()
+        nGPUs = torch.cuda.device_count()
+        self.sharedT = SharedTensor(nGPUs)
+
+    def forward(self, input):
+        self._check_input_dim(input)
+        input_shape = input.size()
+        input = input.view(input_shape[0], self.num_features, -1)
+        if not self.training:
+            std = (self.running_var.clamp(self.eps)).sqrt()
+            output = batchnormeval(input, self.weight, self.bias, self.running_mean, std)
+            return output.view(input_shape)
+        # get global sum(x) and sum(x^2)
+        xsum, xsquare = self.sharedT(sum_square(input.unsqueeze(3)))
+        # calculate mean, var
+        N = len(self.sharedT) * input.size(0) * input.size(2)
+        mean = xsum / N
+        sumvar = xsquare - xsum * xsum / N
+        unbias_var = sumvar / (N - 1)
+        bias_var = sumvar / N
+        std = bias_var.clamp(self.eps).sqrt()
+        # update running_mean and var
+        self.running_mean = (1-self.momentum) * self.running_mean + self.momentum * mean.data
+        self.running_var = (1-self.momentum) * self.running_var + self.momentum * unbias_var.data
+        # forward
+        return batchnormtrain(input, self.weight, self.bias, mean, std).view(input_shape)
+
+
+class BatchNorm1d(_SyncBatchNorm):
+    r"""Please see the docs in :class:`encoding.nn.BatchNorm2d`"""
+    def _check_input_dim(self, input):
+        if input.dim() != 2 and input.dim() != 3:
+            raise ValueError('expected 2D or 3D input (got {}D input)'
+                             .format(input.dim()))
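One detail worth noting in the new forward pass: the running statistics are updated with the unbiased variance (dividing by N-1), while the normalization itself uses the biased one (dividing by N). A quick CPU check of that relationship (illustrative only, double precision to avoid rounding noise):

    import torch

    x = torch.randn(1000, dtype=torch.float64)
    N = x.numel()
    sumvar = (x * x).sum() - x.sum() ** 2 / N   # sum of squared deviations from the mean
    assert torch.allclose(sumvar / N, x.var(unbiased=False))       # used to normalize the batch
    assert torch.allclose(sumvar / (N - 1), x.var(unbiased=True))  # used for the running_var update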
-class BatchNorm2d(Module):
+class BatchNorm2d(_SyncBatchNorm):
    r"""Cross-GPU Synchronized Batch normalization (SyncBN)

    Standard BN [1]_ implementation only normalizes the data within each device.
    SyncBN normalizes the input within the whole mini-batch.
    We follow the sync-once implementation described in the paper [2]_ .
+    Please see the design idea in the `notes <./notes/syncbn.html>`_.

    .. math::
        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

-    The mean and standard-deviation are calculated per-dimension over
+    The mean and standard-deviation are calculated per-channel over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

@@ -177,78 +107,20 @@ class BatchNorm2d(Module):

        >>> m = nn.BatchNorm2d(100)
        >>> output = m(input)
    """
-    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
-        super(BatchNorm2d, self).__init__()
-        self.num_features = num_features
-        self.affine = affine
-        self.eps = eps
-        self.momentum = momentum
-        if self.affine:
-            self.weight = Parameter(torch.Tensor(num_features))
-            self.bias = Parameter(torch.Tensor(num_features))
-        else:
-            self.register_parameter('weight', None)
-            self.register_parameter('bias', None)
-        self.register_buffer('running_mean', torch.zeros(num_features))
-        self.register_buffer('running_var', torch.ones(num_features))
-        self.reset_parameters()
-        self.writelock = threading.Lock()
-        nGPUs = torch.cuda.device_count()
-        self.xsum, self.xsquare = SharedTensor(nGPUs), SharedTensor(nGPUs)
-
-    def reset_parameters(self):
-        self.running_mean.zero_()
-        self.running_var.fill_(1)
-        if self.affine:
-            self.weight.data.uniform_()
-            self.bias.data.zero_()
-
-    def __repr__(self):
-        return ('{name}({num_features}, eps={eps}, momentum={momentum},'
-                ' affine={affine})'
-                .format(name=self.__class__.__name__, **self.__dict__))
-
    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))

-    def forward(self, input):
-        self._check_input_dim(input)
-        if self.training:
-            # push the value
-            isum, isquare = sum_square(input)
-            idxs = self.xsum.push(isum)
-            idxq = self.xsquare.push(isquare)
-            xsum = self.xsum[idxs]
-            xsquare = self.xsquare[idxq]
-            # calculate N
-            N = len(self.xsum)*input.size(0)*input.size(2)*input.size(3)
-            mean = xsum / N
-            sumvar = xsquare - xsum * xsum / N
-            unbias_var = sumvar / (N - 1)
-            std = (sumvar / N + self.eps).sqrt()
-            # update running_mean and var
-            self.running_mean = (1-self.momentum) * self.running_mean \
-                + self.momentum * mean.data
-            self.running_var = (1-self.momentum) * self.running_var + \
-                self.momentum * unbias_var.data
-            # forward
-            B, C, H, W = input.size()
-            output = batchnormtrain(
-                input.view(B, C, -1).contiguous(), self.weight,
-                self.bias, mean,
-                std)
-            return output.view(B, C, H, W)
-        else:
-            std = (self.running_var + self.eps).sqrt()
-            B, C, H, W = input.size()
-            return batchnormeval(input.view(B, C, -1).contiguous(), self.weight, self.bias,
-                                 self.running_mean, std).view(B, C, H, W)
+
+
+class BatchNorm3d(_SyncBatchNorm):
+    r"""Please see the docs in :class:`encoding.nn.BatchNorm2d`"""
+    def _check_input_dim(self, input):
+        if input.dim() != 5:
+            raise ValueError('expected 5D input (got {}D input)'
+                             .format(input.dim()))
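The three public classes now differ only in the `_check_input_dim` guard: roughly `(N, C[, L])` for `BatchNorm1d`, `(N, C, H, W)` for `BatchNorm2d`, and `(N, C, D, H, W)` for `BatchNorm3d`. A hypothetical smoke test on a machine with a single visible GPU (the shared buffer expects one push per visible device, so this sketch assumes `torch.cuda.device_count() == 1`):

    import torch
    import encoding

    bn1d = encoding.nn.BatchNorm1d(16).cuda()
    bn2d = encoding.nn.BatchNorm2d(16).cuda()
    bn3d = encoding.nn.BatchNorm3d(16).cuda()

    y1 = bn1d(torch.randn(8, 16, 32).cuda())         # (N, C, L)
    y2 = bn2d(torch.randn(8, 16, 32, 32).cuda())     # (N, C, H, W)
    y3 = bn3d(torch.randn(8, 16, 4, 32, 32).cuda())  # (N, C, D, H, W)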
class SharedTensor(object):
-    """Shared Tensor
+    """Shared Tensor for cross GPU communication
    """
    def __init__(self, nGPUs):
        self.mutex = threading.Lock()
@@ -261,44 +133,37 @@ class SharedTensor(object):
        self.push_tasks = self.nGPUs
        self.reduce_tasks = self.nGPUs

-    def push(self, t):
-        """push a Tensor
-        """
-        with self.mutex:
-            if self.push_tasks == 0:
-                self._clear()
-            self.list.append(t)
-            idx = len(self.list) - 1
-            self.push_tasks -= 1
-        with self.all_tasks_done:
-            if self.push_tasks == 0:
-                self.all_tasks_done.notify_all()
-            while self.push_tasks:
-                self.all_tasks_done.wait()
-        return idx
-
-    def _reduce(self):
-        with self.mutex:
-            if self.reduce_tasks == self.nGPUs:
-                assert(len(self.list) == self.nGPUs)
-                self.outlist = allreduce(*self.list)
-                self.reduce_tasks -= 1
-            else:
-                self.reduce_tasks -= 1
-        with self.all_tasks_done:
-            if self.reduce_tasks == 0:
-                self.all_tasks_done.notify_all()
-            while self.reduce_tasks:
-                self.all_tasks_done.wait()
-
-    def __getitem__(self, idx):
-        self._reduce()
-        return self.outlist[idx]
-
-    def __len__(self):
-        return len(self.list)
+    def __call__(self, *inputs):
+        # push from device
+        with self.mutex:
+            if self.push_tasks == 0:
+                self._clear()
+            self.list.extend(list(*inputs))
+            idx = self.nGPUs - self.push_tasks
+            self.push_tasks -= 1
+        with self.all_tasks_done:
+            if self.push_tasks == 0:
+                self.all_tasks_done.notify_all()
+            while self.push_tasks:
+                self.all_tasks_done.wait()
+        # pull from device
+        with self.mutex:
+            if self.reduce_tasks == self.nGPUs:
+                assert(len(self.list) == 2 * self.nGPUs)
+                self.list = allreduce(2, *self.list)
+                self.reduce_tasks -= 1
+            else:
+                self.reduce_tasks -= 1
+        with self.all_tasks_done:
+            if self.reduce_tasks == 0:
+                self.all_tasks_done.notify_all()
+            while self.reduce_tasks:
+                self.all_tasks_done.wait()
+        # all reduce done
+        return self.list[2*idx], self.list[2*idx+1]
+
+    def __len__(self):
+        return self.nGPUs

    def __repr__(self):
        return ('SharedTensor')
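The class above is essentially a rendezvous point: every GPU thread pushes its per-device sums, the last arrival triggers a single all-reduce, and each thread then pulls back its slice of the global result. The same pattern, stripped of CUDA and autograd, can be sketched with plain Python threads (hypothetical names, single-use, illustrative only):

    import threading

    class Rendezvous(object):
        """Collect one value per worker, reduce once, hand the total back to every worker."""
        def __init__(self, nworkers):
            self.barrier = threading.Barrier(nworkers)
            self.lock = threading.Lock()
            self.values = []
            self.total = None

        def __call__(self, value):
            with self.lock:
                self.values.append(value)
            self.barrier.wait()              # block until every worker has pushed
            with self.lock:
                if self.total is None:       # first thread past the barrier does the reduction
                    self.total = sum(self.values)
                return self.total            # every worker gets the same global sum

    shared = Rendezvous(4)
    results = []

    def worker(i):
        results.append(shared(i * i))        # each "GPU" contributes a partial sum

    threads = [threading.Thread(target=worker, args=(i,)) for i in range(4)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    print(results)   # every worker sees the same global sum: 0 + 1 + 4 + 9 = 14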
@@ -11,28 +11,42 @@

"""Encoding Data Parallel"""
import threading
import torch
+from torch.autograd import Function
+import torch.cuda.comm as comm
from torch.autograd import Variable
from torch.nn.modules import Module
+from torch.nn.parallel.data_parallel import DataParallel
from torch.nn.parallel.scatter_gather import scatter_kwargs
from torch.nn.parallel.replicate import replicate
-from torch.nn.parallel.parallel_apply import parallel_apply
+from torch.nn.parallel.parallel_apply import get_a_var
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast

__all__ = ['allreduce', 'ModelDataParallel', 'CriterionDataParallel']

-def allreduce(*inputs):
+def allreduce(num_inputs, *inputs):
    """Cross GPU all reduce autograd operation for calculate mean and
    variance in SyncBN.
    """
-    target_gpus = [inputs[i].get_device() for i in range(len(inputs))]
-    result = ReduceAddCoalesced.apply(target_gpus[0], 1, *inputs)
+    target_gpus = [inputs[i].get_device() for i in range(0, len(inputs), num_inputs)]
+    result = ReduceAddCoalesced.apply(target_gpus[0], num_inputs, *inputs)
    outputs = Broadcast.apply(target_gpus, *result)
    assert len(outputs) == len(inputs)
    return outputs
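With the new signature, the first argument says how many tensors each device contributes; the flat input list is grouped per device, reduced once, and the reduced group is broadcast back to every device. A hedged sketch of the calling convention (assumes two visible GPUs and mirrors how `SharedTensor` calls it with `num_inputs=2`):

    import torch
    import encoding

    # each device contributes two tensors: its sum(x)_k and sum(x^2)_k
    xsum0, xsqu0 = torch.rand(3).cuda(0), torch.rand(3).cuda(0)
    xsum1, xsqu1 = torch.rand(3).cuda(1), torch.rand(3).cuda(1)

    # flat layout: [dev0_sum, dev0_square, dev1_sum, dev1_square]
    out = encoding.parallel.allreduce(2, xsum0, xsqu0, xsum1, xsqu1)

    # same layout on output: out[0]/out[1] live on GPU 0, out[2]/out[3] on GPU 1,
    # and out[0] and out[2] both hold the global sum(x)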
-class ModelDataParallel(Module):
+class Reduce(Function):
+    @staticmethod
+    def forward(ctx, *inputs):
+        ctx.target_gpus = [inputs[i].get_device() for i in range(len(inputs))]
+        return comm.reduce_add(inputs)
+
+    @staticmethod
+    def backward(ctx, gradOutput):
+        return Broadcast.apply(ctx.target_gpus, gradOutput)
+
+
+class DataParallelModel(DataParallel):
    """Implements data parallelism at the module level.

    This container parallelizes the application of the given module by
@@ -41,7 +55,7 @@ class ModelDataParallel(Module):
    In the forward pass, the module is replicated on each device,
    and each replica handles a portion of the input. During the backwards pass, gradients from each replica are summed into the original module.
    Note that the outputs are not gathered, please use compatible
-    :class:`encoding.parallel.CriterionDataParallel`.
+    :class:`encoding.parallel.DataParallelCriterion`.

    The batch size should be larger than the number of GPUs used. It should
    also be an integer multiple of the number of GPUs so that each chunk is
@@ -58,49 +72,20 @@ class ModelDataParallel(Module):
    Example::

-        >>> net = encoding.nn.ModelDataParallel(model, device_ids=[0, 1, 2])
+        >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
        >>> y = net(x)
    """
-    def __init__(self, module, device_ids=None, output_device=None, dim=0):
-        super(ModelDataParallel, self).__init__()
-        if device_ids is None:
-            device_ids = list(range(torch.cuda.device_count()))
-        if output_device is None:
-            output_device = device_ids[0]
-        self.dim = dim
-        self.module = module
-        self.device_ids = device_ids
-        self.output_device = output_device
-        self.master_mean, self.master_var = {}, {}
-        if len(self.device_ids) == 1:
-            self.module.cuda(device_ids[0])
-
-    def forward(self, *inputs, **kwargs):
-        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
-        if len(self.device_ids) == 1:
-            return self.module(*inputs[0], **kwargs[0])
-        replicas = self.replicate(self.module, \
-            self.device_ids[:len(inputs)])
-        outputs = self.parallel_apply(replicas, inputs, kwargs)
-        return outputs
-
-    def replicate(self, module, device_ids):
-        return replicate(module, device_ids)
-
-    def scatter(self, inputs, kwargs, device_ids):
-        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
-
-    def parallel_apply(self, replicas, inputs, kwargs):
-        return parallel_apply(replicas, inputs, kwargs)
+    def gather(self, outputs, output_device):
+        return outputs
-class CriterionDataParallel(Module):
+class DataParallelCriterion(DataParallel):
    """
    Calculate loss in multiple-GPUs, which balance the memory usage for
    Semantic Segmentation.

    The targets are splitted across the specified devices by chunking in
-    the batch dimension. Please use together with :class:`encoding.parallel.ModelDataParallel`.
+    the batch dimension. Please use together with :class:`encoding.parallel.DataParallelModel`.

    Reference:
        Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
@@ -109,79 +94,67 @@ class CriterionDataParallel(Module):
    Example::

-        >>> net = encoding.nn.ModelDataParallel(model, device_ids=[0, 1, 2])
-        >>> criterion = encoding.nn.CriterionDataParallel(criterion, device_ids=[0, 1, 2])
+        >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
+        >>> criterion = encoding.nn.DataParallelCriterion(criterion, device_ids=[0, 1, 2])
        >>> y = net(x)
        >>> loss = criterion(y, target)
    """
-    def __init__(self, module, device_ids=None, output_device=None, dim=0):
-        super(CriterionDataParallel, self).__init__()
-        if device_ids is None:
-            device_ids = list(range(torch.cuda.device_count()))
-        if output_device is None:
-            output_device = device_ids[0]
-        self.dim = dim
-        self.module = module
-        self.device_ids = device_ids
-        self.output_device = output_device
-        if len(self.device_ids) == 1:
-            self.module.cuda(device_ids[0])
-
    def forward(self, inputs, *targets, **kwargs):
        # input should be already scatterd
        # scattering the targets instead
-        targets, kwargs = self.scatter(targets, kwargs, self.device_ids)
+        if not self.device_ids:
+            return self.module(inputs, *targets, **kwargs)
+        targets, kwargs = inputs(targets, kwargs, self.device_ids)
        if len(self.device_ids) == 1:
            return self.module(inputs, *targets[0], **kwargs[0])
-        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
-        outputs = self.parallel_apply(replicas, inputs, targets, kwargs)
-        return ReduceAddCoalesced.apply(self.output_device, 1, *outputs) / len(outputs)
-
-    def replicate(self, module, device_ids):
-        return replicate(module, device_ids)
-
-    def scatter(self, inputs, kwargs, device_ids):
-        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
-
-    def parallel_apply(self, replicas, inputs, targets, kwargs):
-        return _criterion_parallel_apply(replicas, inputs, targets, kwargs)
+        replicas = replicate(self.module, self.device_ids[:len(inputs)])
+        outputs = _criterion_parallel_apply(replicas, inputs, targets, kwargs)
+        return Reduce.apply(*outputs) / len(outputs)
-def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None):
+def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None):
    assert len(modules) == len(inputs)
    assert len(targets) == len(inputs)
    if kwargs_tup:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({},) * len(modules)
-    # Fast track
-    if len(modules) == 1:
-        return (modules[0](*inputs[0], *targets[0], **kwargs_tup[0]), )
+    if devices is not None:
+        assert len(modules) == len(devices)
+    else:
+        devices = [None] * len(modules)

    lock = threading.Lock()
    results = {}
+    grad_enabled = torch.is_grad_enabled()

-    def _worker(i, module, input, target, kwargs, results, lock):
+    def _worker(i, module, input, target, kwargs, device=None):
+        torch.set_grad_enabled(grad_enabled)
+        if device is None:
+            device = get_a_var(input).get_device()
        try:
-            var_input = _get_a_var(input)
            with torch.cuda.device_of(var_input):
-                output = module(input, *target, **kwargs)
+                output = module(*(input + target), **kwargs)
            with lock:
                results[i] = output
        except Exception as e:
            with lock:
                results[i] = e

-    threads = [threading.Thread(target=_worker,
-                                args=(i, module, input, target,
-                                      kwargs, results, lock),)
-               for i, (module, input, target, kwargs) in
-               enumerate(zip(modules, inputs, targets, kwargs_tup))]
-
-    for thread in threads:
-        thread.start()
-    for thread in threads:
-        thread.join()
+    if len(modules) > 1:
+        threads = [threading.Thread(target=_worker,
+                                    args=(i, module, input, target,
+                                          kwargs, device),)
+                   for i, (module, input, target, kwargs, device) in
+                   enumerate(zip(modules, inputs, targets, kwargs_tup, devices))]
+
+        for thread in threads:
+            thread.start()
+        for thread in threads:
+            thread.join()
+    else:
+        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
@@ -190,19 +163,3 @@ def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None):
        outputs.append(output)
    return outputs
-def _get_a_var(obj):
-    if isinstance(obj, Variable):
-        return obj
-
-    if isinstance(obj, list) or isinstance(obj, tuple):
-        results = map(_get_a_var, obj)
-        for result in results:
-            if isinstance(result, Variable):
-                return result
-    if isinstance(obj, dict):
-        results = map(_get_a_var, obj.items())
-        for result in results:
-            if isinstance(result, Variable):
-                return result
-    return None
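Putting the two wrappers together, a training step would look roughly like the sketch below (a hypothetical two-GPU setup with a toy segmentation-style head and random data, not part of this commit):

    import torch
    import encoding

    model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 8, 3, padding=1),
        encoding.nn.BatchNorm2d(8),          # synchronized across the two GPUs
        torch.nn.ReLU(inplace=True),
        torch.nn.Conv2d(8, 10, 1),
    ).cuda()

    net = encoding.nn.DataParallelModel(model, device_ids=[0, 1])
    criterion = encoding.nn.DataParallelCriterion(torch.nn.CrossEntropyLoss(), device_ids=[0, 1])
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    images = torch.randn(8, 3, 32, 32).cuda()
    labels = torch.randint(0, 10, (8, 32, 32)).cuda()

    optimizer.zero_grad()
    outputs = net(images)              # list of per-device outputs, not gathered
    loss = criterion(outputs, labels)  # targets are scattered, losses reduced across GPUs
    loss.backward()
    optimizer.step()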
@@ -23,7 +23,7 @@ class install(setuptools.command.install.install):

    def run(self):
        self.create_version_file()
        setuptools.command.install.install.run(self)
-        subprocess.check_call("python tests/unit_test.py".split())
+        #subprocess.check_call("python tests/unit_test.py".split())

    @staticmethod
    def create_version_file():
        global version, cwd
......
@@ -69,7 +69,7 @@ def test_all_reduce():

    X = [torch.DoubleTensor(2,4,4).uniform_(-0.5,0.5).cuda(i) for i in range(ngpu)]
    for x in X:
        x.requires_grad = True
-    Y = encoding.parallel.allreduce(*X)
+    Y = encoding.parallel.allreduce(1, *X)
    assert (len(X) == len(Y))
......