Unverified commit cebf1341 authored by Hang Zhang, committed by GitHub

Adapt SyncBN API from Other's Work (#52)

* update and fix bugs

* adapt syncbn api from other work

* typo
parent 67e153dd
@@ -14,6 +14,8 @@ import platform
 import subprocess
 from torch.utils.ffi import create_extension
+torch_ver = torch.__version__[:3]
 lib_path = os.path.join(os.path.dirname(torch.__file__), 'lib')
 cwd = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'encoding/')
 encoding_lib_path = os.path.join(cwd, "lib")
@@ -25,12 +27,18 @@ subprocess.check_call(clean_cmd)
 # build CUDA library
 os.environ['TORCH_BUILD_DIR'] = lib_path
 if platform.system() == 'Darwin':
-    os.environ['TH_LIBRARIES'] = os.path.join(lib_path,'libATen.dylib')
+    if torch_ver == '0.3':
+        os.environ['TH_LIBRARIES'] = os.path.join(lib_path,'libATen.1.dylib')
+    else:
+        os.environ['TH_LIBRARIES'] = os.path.join(lib_path,'libATen.dylib')
     ENCODING_LIB = os.path.join(cwd, 'lib/libENCODING.dylib')
 else:
     os.environ['CFLAGS'] = '-std=c99'
-    os.environ['TH_LIBRARIES'] = os.path.join(lib_path,'libATen.so')
+    if torch_ver == '0.3':
+        os.environ['TH_LIBRARIES'] = os.path.join(lib_path,'libATen.so.1')
+    else:
+        os.environ['TH_LIBRARIES'] = os.path.join(lib_path,'libATen.so')
     ENCODING_LIB = os.path.join(cwd, 'lib/libENCODING.so')
 build_all_cmd = ['bash', 'encoding/make.sh']
@@ -4,7 +4,7 @@
 encoding.functions
 ==================
-.. automodule:: encoding.Functions
+.. automodule:: encoding.functions
 .. currentmodule:: encoding.functions
@@ -9,8 +9,6 @@ Created by `Hang Zhang <http://hangzh.com/>`_
 An optimized PyTorch package with CUDA backend.
-.. note::
-    PyTorch compatible Synchronized Cross-GPU :class:`encoding.nn.SyncBatchNorm2d` and the `MNIST example <https://github.com/zhanghang1989/PyTorch-SyncBatchNorm>`_.
 .. toctree::
    :glob:
.. role:: hidden
    :class: hidden-section

encoding.nn
===========

Customized NN modules in Encoding Package. For Synchronized Cross-GPU Batch Normalization, please visit :class:`encoding.nn.SyncBatchNorm2d`.

.. currentmodule:: encoding.nn

:hidden:`Encoding`
~~~~~~~~~~~~~~~~~~

.. autoclass:: Encoding
    :members:

:hidden:`SyncBatchNorm2d`
~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: SyncBatchNorm2d
    :members:

:hidden:`SyncBatchNorm1d`
~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: SyncBatchNorm1d
    :members:

:hidden:`SyncBatchNorm3d`
~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: SyncBatchNorm3d
    :members:

:hidden:`Inspiration`
~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: Inspiration
    :members:

:hidden:`UpsampleConv2d`
~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: UpsampleConv2d
    :members:

:hidden:`DilatedAvgPool2d`
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: DilatedAvgPool2d
    :members:

:hidden:`GramMatrix`
~~~~~~~~~~~~~~~~~~~~

.. autoclass:: GramMatrix
    :members:
# -*- coding: utf-8 -*-
# File   : comm.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import queue
import collections
import threading

__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']


class FutureResult(object):
    """A thread-safe future implementation. Used only as one-to-one pipe."""

    def __init__(self):
        self._result = None
        self._lock = threading.Lock()
        self._cond = threading.Condition(self._lock)

    def put(self, result):
        with self._lock:
            assert self._result is None, 'Previous result hasn\'t been fetched.'
            self._result = result
            self._cond.notify()

    def get(self):
        with self._lock:
            if self._result is None:
                self._cond.wait()
            res = self._result
            self._result = None
            return res


_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])


class SlavePipe(_SlavePipeBase):
    """Pipe for master-slave communication."""

    def run_slave(self, msg):
        self.queue.put((self.identifier, msg))
        ret = self.result.get()
        self.queue.put(True)
        return ret


class SyncMaster(object):
    """An abstract `SyncMaster` object.

    - During replication, as data parallel triggers a callback on each module, every slave
      device should call `register(id)` to obtain a `SlavePipe` for communicating with the master.
    - During the forward pass, the master device invokes `run_master`; all messages from the
      slave devices are collected and passed to a registered callback.
    - After receiving the messages, the master device gathers the information and determines
      the message to be passed back to each slave device.
    """

    def __init__(self, master_callback):
        """
        Args:
            master_callback: a callback to be invoked after all messages from the slave
                devices have been collected.
        """
        self._master_callback = master_callback
        self._queue = queue.Queue()
        self._registry = collections.OrderedDict()
        self._activated = False

    def register_slave(self, identifier):
        """
        Register a slave device.

        Args:
            identifier: an identifier, usually the device id.

        Returns: a `SlavePipe` object that can be used to communicate with the master device.
        """
        if self._activated:
            assert self._queue.empty(), 'Queue is not clean before next initialization.'
            self._activated = False
            self._registry.clear()
        future = FutureResult()
        self._registry[identifier] = _MasterRegistry(future)
        return SlavePipe(identifier, self._queue, future)

    def run_master(self, master_msg):
        """
        Main entry for the master device in each forward pass.
        Messages are first collected from each device (including the master device), then a
        callback is invoked to compute the message to be sent back to each device
        (including the master device).

        Args:
            master_msg: the message that the master wants to send to itself. This will be
                placed as the first message when calling `master_callback`. For detailed
                usage, see `_SynchronizedBatchNorm` for an example.

        Returns: the message to be sent back to the master device.
        """
        self._activated = True

        intermediates = [(0, master_msg)]
        for i in range(self.nr_slaves):
            intermediates.append(self._queue.get())

        results = self._master_callback(intermediates)
        assert results[0][0] == 0, 'The first result should belong to the master.'

        for i, res in results:
            if i == 0:
                continue
            self._registry[i].result.put(res)

        for i in range(self.nr_slaves):
            assert self._queue.get() is True

        return results[0][1]

    @property
    def nr_slaves(self):
        return len(self._registry)
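
The master/slave handshake implemented above is easiest to see outside of data parallel. Below is a small illustrative sketch of my own (not part of this commit): plain Python threads stand in for GPU devices, a hypothetical `sum_callback` reduces the payloads, and every participant receives the same reduced value back through its `SlavePipe`.

import threading

def sum_callback(intermediates):
    # intermediates: list of (identifier, message) pairs; the master's entry comes first.
    total = sum(msg for _, msg in intermediates)
    # Reply to every identifier with the same reduced value.
    return [(identifier, total) for identifier, _ in intermediates]

master = SyncMaster(sum_callback)
pipes = {rank: master.register_slave(rank) for rank in (1, 2)}
results = {}

def slave(rank, value):
    # run_slave blocks until the master has collected everything and replied.
    results[rank] = pipes[rank].run_slave(value)

threads = [threading.Thread(target=slave, args=(rank, 10 * rank)) for rank in (1, 2)]
for t in threads:
    t.start()
results[0] = master.run_master(1)   # the master contributes 1 and collects 10 and 20
for t in threads:
    t.join()
assert all(v == 31 for v in results.values())
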
@@ -9,7 +9,6 @@
 ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 """Synchronized Cross-GPU Batch Normalization Module"""
 import functools
 import collections
 import threading
 import torch
@@ -22,52 +21,96 @@ from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
 from ..functions import *
 from ..parallel import allreduce
+from .comm import SyncMaster

 __all__ = ['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'Module', 'Sequential', 'Conv1d',
            'Conv2d', 'ConvTranspose2d', 'ReLU', 'Sigmoid', 'MaxPool2d', 'AvgPool2d',
            'AdaptiveAvgPool2d', 'Dropout2d', 'Linear']

+# Adapted from https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+_ChildMessage = collections.namedtuple('Message', ['sum', 'ssum', 'sum_size'])
+_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
 class _SyncBatchNorm(_BatchNorm):
     def __init__(self, num_features, eps=1e-5, momentum=0.001, affine=True):
         super(_SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
+        self._sync_master = SyncMaster(self._data_parallel_master)
+        self._is_parallel = False
+        self._parallel_id = None
+        self._slave_pipe = None
-        self.sharedT = SharedTensor(torch.cuda.device_count())

     def forward(self, input):
+        # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
+        if not (self._is_parallel and self.training):
+            return batch_norm(
+                input, self.running_mean, self.running_var, self.weight, self.bias,
+                self.training, self.momentum, self.eps)
+        # Resize the input to (B, C, -1).
         input_shape = input.size()
+        input = input.view(input_shape[0], self.num_features, -1)
-        if not self.training:
-            std = (self.running_var.clamp(self.eps)).sqrt()
-            output = batchnormeval(input, self.weight, self.bias, self.running_mean, std)
-            return output.view(input_shape)
-        input = input.view(input.size(0), self.num_features, -1)
         # sum(x) and sum(x^2)
         N = input.size(0) * input.size(2)
         xsum, xsqsum = sum_square(input)
-        # all-reduce for global sum(x) and sum(x^2)
-        igpu = input.get_device()
-        self.sharedT.push(N, igpu, xsum, xsqsum)
-        N, xsum, xsqsum = self.sharedT.pull(igpu)
-        # calculate mean, var
-        mean = xsum / N
-        sumvar = xsqsum - xsum * xsum / N
-        unbias_var = sumvar / (N - 1)
-        bias_var = sumvar / N
-        std = bias_var.clamp(self.eps).sqrt()
-        # update running_mean and var
-        self.running_mean = (1-self.momentum) * self.running_mean + self.momentum * mean.data
-        self.running_var = (1-self.momentum) * self.running_var + self.momentum * unbias_var.data
-        # forward
-        return batchnormtrain(input, self.weight, self.bias, mean, std).view(input_shape)
+        if self._parallel_id == 0:
+            mean, inv_std = self._sync_master.run_master(_ChildMessage(xsum, xsqsum, N))
+        else:
+            mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(xsum, xsqsum, N))
+        # forward
+        return batchnormtrain(input, self.weight, self.bias, mean, 1.0/inv_std).view(input_shape)

+    def __data_parallel_replicate__(self, ctx, copy_id):
+        self._is_parallel = True
+        self._parallel_id = copy_id
+        # parallel_id == 0 means master device.
+        if self._parallel_id == 0:
+            ctx.sync_master = self._sync_master
+        else:
+            self._slave_pipe = ctx.sync_master.register_slave(copy_id)

+    def _data_parallel_master(self, intermediates):
+        """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
+        # Always using the same "device order" makes the ReduceAdd operation faster.
+        # Thanks to: Tete Xiao (http://tetexiao.com/)
+        intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
+        to_reduce = [i[1][:2] for i in intermediates]
+        to_reduce = [j for i in to_reduce for j in i]  # flatten
+        target_gpus = [i[1].sum.get_device() for i in intermediates]
+        sum_size = sum([i[1].sum_size for i in intermediates])
+        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
+        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
+        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
+        outputs = []
+        for i, rec in enumerate(intermediates):
+            outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
+        return outputs

+    def _compute_mean_std(self, sum_, ssum, size):
+        """Compute the mean and standard-deviation with sum and square-sum. This method
+        also maintains the moving average on the master device."""
+        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
+        mean = sum_ / size
+        sumvar = ssum - sum_ * mean
+        unbias_var = sumvar / (size - 1)
+        bias_var = sumvar / size
+        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
+        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
+        return mean, bias_var.clamp(self.eps) ** -0.5
 class BatchNorm1d(_SyncBatchNorm):
@@ -82,13 +125,15 @@ class BatchNorm1d(_SyncBatchNorm):
 class BatchNorm2d(_SyncBatchNorm):
     r"""Cross-GPU Synchronized Batch normalization (SyncBN)

-    Standard BN [1]_ implementation only normalize the data within each device.
+    Standard BN [1]_ implementation only normalizes the data within each device (GPU).
     SyncBN normalizes the input within the whole mini-batch.
     We follow the sync-once implementation described in the paper [2]_.
     Please see the design idea in the `notes <./notes/syncbn.html>`_.

     .. note::
-        Please use ``CUDA_VISIBLE_DEVICES`` to select number of GPUs.
+        We adapt the awesome python API from another `PyTorch SyncBN Implementation
+        <https://github.com/vacancy/Synchronized-BatchNorm-PyTorch>`_ and provide
+        efficient CUDA backend.

     .. math::
@@ -125,9 +170,9 @@ class BatchNorm2d(_SyncBatchNorm):
     .. [2] Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, and Amit Agrawal. "Context Encoding for Semantic Segmentation." *CVPR 2018*

     Examples:
         >>> # Use exactly the same as standard BatchNorm2d
         >>> m = BatchNorm2d(100)
         >>> net = torch.nn.DataParallel(m)
         >>> encoding.parallel.patch_replication_callback(net)
         >>> output = net(input)
     """

     def _check_input_dim(self, input):
@@ -148,11 +193,12 @@ class BatchNorm3d(_SyncBatchNorm):
 class SharedTensor(object):
     """Shared Tensor for cross GPU all reduce operation"""

-    def __init__(self, nGPUs):
+    def __init__(self, nGPUs, op):
         self.mutex = threading.Lock()
         self.all_tasks_done = threading.Condition(self.mutex)
         self.nGPUs = nGPUs
         self._clear()
+        self.op = op

     def _clear(self):
         self.N = 0
@@ -160,7 +206,7 @@ class SharedTensor(object):
         self.push_tasks = self.nGPUs
         self.reduce_tasks = self.nGPUs

-    def push(self, *inputs):
+    def __call__(self, *inputs):
         if self.nGPUs <= 1:
             return tuple(inputs)
         # push from device
@@ -177,15 +223,13 @@ class SharedTensor(object):
                 self.all_tasks_done.notify_all()
             while self.push_tasks:
                 self.all_tasks_done.wait()

-    def pull(self, igpu):
         # pull from device
         with self.mutex:
             if igpu == 0:
                 assert(len(self.dict) == self.nGPUs)
                 # flatten the tensors
                 self.list = [t for i in range(len(self.dict)) for t in self.dict[i]]
-                self.outlist = allreduce(2, *self.list)
+                self.outlist = self.op(2, *self.list)
                 self.reduce_tasks -= 1
             else:
                 self.reduce_tasks -= 1
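
The arithmetic behind `_compute_mean_std` (and the per-device code it replaces) deserves a sanity check: since sum((x - mean)^2) = sum(x^2) - mean * sum(x), each device only has to ship its per-channel sum(x), sum(x^2), and element count, and the master still recovers exact whole-mini-batch statistics. The following is my own CPU-only sketch in plain PyTorch, standing in for the package's `sum_square` CUDA kernel:

import torch

# Two "devices", each holding a (B, C, L) shard of the same mini-batch.
shards = [torch.randn(4, 3, 8), torch.randn(6, 3, 8)]

def partial_stats(x):
    # Per-channel sum(x), sum(x^2), and element count, like sum_square.
    flat = x.transpose(0, 1).contiguous().view(x.size(1), -1)  # (C, B*L)
    return flat.sum(1), (flat * flat).sum(1), flat.size(1)

xsums, xsqsums, sizes = zip(*[partial_stats(x) for x in shards])

# Master-side reduction, mirroring _compute_mean_std.
sum_, ssum, size = sum(xsums), sum(xsqsums), sum(sizes)
mean = sum_ / size
sumvar = ssum - sum_ * mean          # equals sum((x - mean)^2)
bias_var = sumvar / size             # biased (population) variance
inv_std = bias_var.clamp(1e-5) ** -0.5

# The reduced statistics match those of the full concatenated batch.
full = torch.cat(shards, 0).transpose(0, 1).contiguous().view(3, -1)
assert torch.allclose(mean, full.mean(1), atol=1e-5)
assert torch.allclose(bias_var, full.var(1, unbiased=False), atol=1e-5)
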
@@ -10,6 +10,7 @@
"""Encoding Data Parallel"""
import threading
import functools
import torch
from torch.autograd import Variable, Function
import torch.cuda.comm as comm
@@ -17,10 +18,11 @@ from torch.nn.parallel.data_parallel import DataParallel
 from torch.nn.parallel.parallel_apply import get_a_var
 from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast

-__all__ = ['allreduce', 'DataParallelModel', 'DataParallelCriterion']
 torch_ver = torch.__version__[:3]

+__all__ = ['allreduce', 'DataParallelModel', 'DataParallelCriterion',
+           'patch_replication_callback']
 def allreduce(*inputs):
     """Cross GPU all reduce autograd operation for calculating mean and
     variance in SyncBN.
@@ -94,6 +96,11 @@ class DataParallelModel(DataParallel):
     def gather(self, outputs, output_device):
         return outputs

+    def replicate(self, module, device_ids):
+        modules = super(DataParallelModel, self).replicate(module, device_ids)
+        execute_replication_callbacks(modules)
+        return modules
 class DataParallelCriterion(DataParallel):
     """
@@ -181,3 +188,61 @@ def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices
             raise output
         outputs.append(output)
     return outputs
+
+###########################################################################
+# Adapted from Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+
+class CallbackContext(object):
+    pass
+
+
+def execute_replication_callbacks(modules):
+    """
+    Execute a replication callback `__data_parallel_replicate__` on each module created
+    by the original replication.
+
+    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`.
+
+    Note that, as all modules are isomorphic, we assign each sub-module a context
+    (shared among multiple copies of this module on different devices).
+    Through this context, different copies can share some information.
+
+    We guarantee that the callback on the master copy (the first copy) will be invoked
+    before the callbacks on any slave copies.
+    """
+    master_copy = modules[0]
+    nr_modules = len(list(master_copy.modules()))
+    ctxs = [CallbackContext() for _ in range(nr_modules)]
+
+    for i, module in enumerate(modules):
+        for j, m in enumerate(module.modules()):
+            if hasattr(m, '__data_parallel_replicate__'):
+                m.__data_parallel_replicate__(ctxs[j], i)
+
+
+def patch_replication_callback(data_parallel):
+    """
+    Monkey-patch an existing `DataParallel` object. Add the replication callback.
+    Useful when you have a customized `DataParallel` implementation.
+
+    Examples:
+        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
+        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
+        > patch_replication_callback(sync_bn)
+        # this is equivalent to
+        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
+        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
+    """
+    assert isinstance(data_parallel, DataParallel)
+
+    old_replicate = data_parallel.replicate
+
+    @functools.wraps(old_replicate)
+    def new_replicate(module, device_ids):
+        modules = old_replicate(module, device_ids)
+        execute_replication_callbacks(modules)
+        return modules
+
+    data_parallel.replicate = new_replicate
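
To see the replication callback machinery in action without GPUs, here is a hypothetical sketch of my own (not part of the commit): it substitutes `copy.deepcopy` for the real `replicate` and checks that `execute_replication_callbacks` hands every copy of a sub-module the same `CallbackContext`, visiting the master copy (copy 0) first, which is exactly what `_SyncBatchNorm.__data_parallel_replicate__` relies on. The import path `encoding.parallel` is assumed.

import copy
import torch.nn as nn
from encoding.parallel import execute_replication_callbacks

class Recorder(nn.Module):
    """Toy module that opts into the replication callback."""
    def forward(self, x):
        return x

    def __data_parallel_replicate__(self, ctx, copy_id):
        self.copy_id = copy_id
        self.ctx = ctx             # the same CallbackContext object on every copy
        if copy_id == 0:
            ctx.master = self      # the master copy registers itself first

net = nn.Sequential(Recorder(), nn.Linear(4, 4))
replicas = [net, copy.deepcopy(net)]   # stand-ins for per-GPU replicas
execute_replication_callbacks(replicas)

rec0, rec1 = replicas[0][0], replicas[1][0]
assert (rec0.copy_id, rec1.copy_id) == (0, 1)
assert rec0.ctx is rec1.ctx            # one shared context per sub-module
assert rec0.ctx.master is rec0         # master copy was visited first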