"examples/vscode:/vscode.git/clone" did not exist on "4f7972957b02088b13c9d23d71646d2d0bab3e4e"
Unverified commit cebf1341 authored by Hang Zhang, committed by GitHub

Adapt SyncBN API from Other's Work (#52)

* update and fix bugs

* adapt syncbn api from other work

* typo
parent 67e153dd
@@ -14,6 +14,8 @@ import platform
import subprocess
from torch.utils.ffi import create_extension
torch_ver = torch.__version__[:3]
lib_path = os.path.join(os.path.dirname(torch.__file__), 'lib')
cwd = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'encoding/')
encoding_lib_path = os.path.join(cwd, "lib")
@@ -25,11 +27,17 @@ subprocess.check_call(clean_cmd)
# build CUDA library
os.environ['TORCH_BUILD_DIR'] = lib_path
if platform.system() == 'Darwin':
if torch_ver == '0.3':
os.environ['TH_LIBRARIES'] = os.path.join(lib_path,'libATen.1.dylib')
else:
os.environ['TH_LIBRARIES'] = os.path.join(lib_path,'libATen.dylib')
ENCODING_LIB = os.path.join(cwd, 'lib/libENCODING.dylib')
else:
os.environ['CFLAGS'] = '-std=c99'
if torch_ver == '0.3':
os.environ['TH_LIBRARIES'] = os.path.join(lib_path,'libATen.so.1')
else:
os.environ['TH_LIBRARIES'] = os.path.join(lib_path,'libATen.so')
ENCODING_LIB = os.path.join(cwd, 'lib/libENCODING.so')
......
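The version check added above branches on the first three characters of torch.__version__ so the build links against the ATen library name each release actually ships (libATen.so.1 on PyTorch 0.3.x, libATen.so otherwise, with the matching .dylib names on macOS). A minimal sketch of that check outside the build script, assuming only that PyTorch is importable:

import os
import torch

lib_path = os.path.join(os.path.dirname(torch.__file__), 'lib')
torch_ver = torch.__version__[:3]                      # e.g. '0.3' or '0.4'
aten = 'libATen.so.1' if torch_ver == '0.3' else 'libATen.so'
print(os.path.join(lib_path, aten))                    # the library setup.py would link against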
@@ -4,7 +4,7 @@
encoding.functions
==================
- .. automodule:: encoding.Functions
.. automodule:: encoding.functions
.. currentmodule:: encoding.functions
......
@@ -9,8 +9,6 @@ Created by `Hang Zhang <http://hangzh.com/>`_
An optimized PyTorch package with CUDA backend.
- .. note::
- PyTorch compatible Synchronized Cross-GPU :class:`encoding.nn.SyncBatchNorm2d` and the `MNIST example <https://github.com/zhanghang1989/PyTorch-SyncBatchNorm>`_.
.. toctree::
:glob:
......
.. role:: hidden
:class: hidden-section
encoding.nn
===========
Customized NN modules in Encoding Package. For Synchronized Cross-GPU Batch Normalization, please visit :class:`encoding.nn.SyncBatchNorm2d`.
.. currentmodule:: encoding.nn
:hidden:`Encoding`
~~~~~~~~~~~~~~~~~~
.. autoclass:: Encoding
:members:
:hidden:`SyncBatchNorm2d`
~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: SyncBatchNorm2d
:members:
:hidden:`SyncBatchNorm1d`
~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: SyncBatchNorm1d
:members:
:hidden:`SyncBatchNorm3d`
~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: SyncBatchNorm3d
:members:
:hidden:`Inspiration`
~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: Inspiration
:members:
:hidden:`UpsampleConv2d`
~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: UpsampleConv2d
:members:
:hidden:`DilatedAvgPool2d`
~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: DilatedAvgPool2d
:members:
:hidden:`GramMatrix`
~~~~~~~~~~~~~~~~~~~~
.. autoclass:: GramMatrix
:members:
# -*- coding: utf-8 -*-
# File : comm.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
import queue
import collections
import threading
__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
class FutureResult(object):
"""A thread-safe future implementation. Used only as one-to-one pipe."""
def __init__(self):
self._result = None
self._lock = threading.Lock()
self._cond = threading.Condition(self._lock)
def put(self, result):
with self._lock:
assert self._result is None, 'Previous result hasn\'t been fetched.'
self._result = result
self._cond.notify()
def get(self):
with self._lock:
if self._result is None:
self._cond.wait()
res = self._result
self._result = None
return res
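A quick illustration of the one-to-one pipe semantics (hypothetical usage, not part of the commit): get() blocks until another thread calls put() exactly once, and the slot is cleared when the result is fetched.

import threading

fut = FutureResult()

def producer():
    fut.put(42)                       # wakes up exactly one waiting consumer

threading.Thread(target=producer).start()
print(fut.get())                      # blocks until put() is called, then prints 42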
_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
class SlavePipe(_SlavePipeBase):
"""Pipe for master-slave communication."""
def run_slave(self, msg):
self.queue.put((self.identifier, msg))
ret = self.result.get()
self.queue.put(True)
return ret
class SyncMaster(object):
"""An abstract `SyncMaster` object.
- During replication, as data parallel triggers a callback on each module, every slave device should
call `register_slave(id)` to obtain a `SlavePipe` for communicating with the master.
- During the forward pass, the master device invokes `run_master`; all messages from the slave devices are
collected and passed to the registered callback.
- After receiving the messages, the master device gathers the information and determines the message to be
passed back to each slave device.
"""
def __init__(self, master_callback):
"""
Args:
master_callback: a callback to be invoked after having collected messages from slave devices.
"""
self._master_callback = master_callback
self._queue = queue.Queue()
self._registry = collections.OrderedDict()
self._activated = False
def register_slave(self, identifier):
"""
Register a slave device.
Args:
identifier: an identifier, usually the device id.
Returns: a `SlavePipe` object which can be used to communicate with the master device.
"""
if self._activated:
assert self._queue.empty(), 'Queue is not clean before next initialization.'
self._activated = False
self._registry.clear()
future = FutureResult()
self._registry[identifier] = _MasterRegistry(future)
return SlavePipe(identifier, self._queue, future)
def run_master(self, master_msg):
"""
Main entry for the master device in each forward pass.
Messages are first collected from each device (including the master device), and then
a callback is invoked to compute the message to be sent back to each device
(including the master device).
Args:
master_msg: the message that the master wants to send to itself. This will be placed as the first
message when calling `master_callback`. For detailed usage, see `_SyncBatchNorm` for an example.
Returns: the message to be sent back to the master device.
"""
self._activated = True
intermediates = [(0, master_msg)]
for i in range(self.nr_slaves):
intermediates.append(self._queue.get())
results = self._master_callback(intermediates)
assert results[0][0] == 0, 'The first result should belong to the master.'
for i, res in results:
if i == 0:
continue
self._registry[i].result.put(res)
for i in range(self.nr_slaves):
assert self._queue.get() is True
return results[0][1]
@property
def nr_slaves(self):
return len(self._registry)
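To make the register_slave / run_master protocol concrete, here is a minimal single-process sketch (hypothetical; in the package this is driven by the data-parallel replication callback). The master thread calls run_master while each slave thread calls run_slave, and the callback sees every message at once:

import threading

def master_callback(intermediates):
    # Hypothetical callback: sum all messages and send the total back to every device.
    total = sum(msg for _, msg in intermediates)
    return [(identifier, total) for identifier, _ in intermediates]

master = SyncMaster(master_callback)
pipes = {i: master.register_slave(i) for i in (1, 2)}      # slaves register before the forward pass

results = {}
def slave(i):
    results[i] = pipes[i].run_slave(i * 10)                # each slave contributes its own message

threads = [threading.Thread(target=slave, args=(i,)) for i in (1, 2)]
for t in threads:
    t.start()
results[0] = master.run_master(0)                          # master contributes 0 and triggers the callback
for t in threads:
    t.join()
print(results)                                             # every entry is 30 (= 0 + 10 + 20)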
@@ -9,7 +9,6 @@
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
"""Synchronized Cross-GPU Batch Normalization Module"""
- import functools
import collections
import threading
import torch
@@ -22,52 +21,96 @@ from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
from ..functions import *
from ..parallel import allreduce
from .comm import SyncMaster
__all__ = ['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'Module', 'Sequential', 'Conv1d',
'Conv2d', 'ConvTranspose2d', 'ReLU', 'Sigmoid', 'MaxPool2d', 'AvgPool2d',
'AdaptiveAvgPool2d', 'Dropout2d', 'Linear']
# Adapt from https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
_ChildMessage = collections.namedtuple('Message', ['sum', 'ssum', 'sum_size'])
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
class _SyncBatchNorm(_BatchNorm):
def __init__(self, num_features, eps=1e-5, momentum=0.001, affine=True):
super(_SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
self._sync_master = SyncMaster(self._data_parallel_master)
self._is_parallel = False
self._parallel_id = None
self._slave_pipe = None
- self.sharedT = SharedTensor(torch.cuda.device_count())
def forward(self, input):
# If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
if not (self._is_parallel and self.training):
return batch_norm(
input, self.running_mean, self.running_var, self.weight, self.bias,
self.training, self.momentum, self.eps)
# Resize the input to (B, C, -1).
input_shape = input.size()
- input = input.view(input_shape[0], self.num_features, -1)
input = input.view(input.size(0), self.num_features, -1)
- if not self.training:
- std = (self.running_var.clamp(self.eps)).sqrt()
- output = batchnormeval(input, self.weight, self.bias, self.running_mean, std)
- return output.view(input_shape)
# sum(x) and sum(x^2)
N = input.size(0) * input.size(2)
xsum, xsqsum = sum_square(input)
# all-reduce for global sum(x) and sum(x^2)
- igpu = input.get_device()
- self.sharedT.push(N, igpu, xsum, xsqsum)
- N, xsum, xsqsum = self.sharedT.pull(igpu)
- # calculate mean, var
- mean = xsum / N
- sumvar = xsqsum - xsum * xsum / N
- unbias_var = sumvar / (N - 1)
- bias_var = sumvar / N
- std = bias_var.clamp(self.eps).sqrt()
- # update running_mean and var
- self.running_mean = (1-self.momentum) * self.running_mean + self.momentum * mean.data
- self.running_var = (1-self.momentum) * self.running_var + self.momentum * unbias_var.data
- # forward
- return batchnormtrain(input, self.weight, self.bias, mean, std).view(input_shape)
if self._parallel_id == 0:
mean, inv_std = self._sync_master.run_master(_ChildMessage(xsum, xsqsum, N))
else:
mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(xsum, xsqsum, N))
# forward
return batchnormtrain(input, self.weight, self.bias, mean, 1.0/inv_std).view(input_shape)
def __data_parallel_replicate__(self, ctx, copy_id):
self._is_parallel = True
self._parallel_id = copy_id
# parallel_id == 0 means master device.
if self._parallel_id == 0:
ctx.sync_master = self._sync_master
else:
self._slave_pipe = ctx.sync_master.register_slave(copy_id)
def _data_parallel_master(self, intermediates):
"""Reduce the sum and square-sum, compute the statistics, and broadcast it."""
# Always using the same "device order" makes the ReduceAdd operation faster.
# Thanks to: Tete Xiao (http://tetexiao.com/)
intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
to_reduce = [i[1][:2] for i in intermediates]
to_reduce = [j for i in to_reduce for j in i] # flatten
target_gpus = [i[1].sum.get_device() for i in intermediates]
sum_size = sum([i[1].sum_size for i in intermediates])
sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
outputs = []
for i, rec in enumerate(intermediates):
outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
return outputs
def _compute_mean_std(self, sum_, ssum, size):
"""Compute the mean and standard-deviation with sum and square-sum. This method
also maintains the moving average on the master device."""
assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
mean = sum_ / size
sumvar = ssum - sum_ * mean
unbias_var = sumvar / (size - 1)
bias_var = sumvar / size
self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
return mean, bias_var.clamp(self.eps) ** -0.5
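The statistics in _compute_mean_std follow the usual identities mean = sum(x)/N and unbiased var = (sum(x^2) - sum(x)^2/N)/(N-1), with inv_std then taken as the eps-clamped biased variance raised to the power -0.5. A small sanity check of that algebra on a toy tensor (illustrative only, independent of the classes above):

import torch

x = torch.randn(4, 3, 8)                          # toy (B, C, L) activations
xsum = x.sum(0).sum(-1)                           # per-channel sum(x)
xsqsum = (x * x).sum(0).sum(-1)                   # per-channel sum(x^2)
n = x.size(0) * x.size(2)

mean = xsum / n
sumvar = xsqsum - xsum * mean
unbias_var = sumvar / (n - 1)                     # same identity as _compute_mean_std

x_flat = x.transpose(0, 1).contiguous().view(3, -1)
print(torch.allclose(mean, x_flat.mean(1)))       # True
print(torch.allclose(unbias_var, x_flat.var(1)))  # True (torch.var is unbiased by default)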
class BatchNorm1d(_SyncBatchNorm):
@@ -82,13 +125,15 @@ class BatchNorm1d(_SyncBatchNorm):
class BatchNorm2d(_SyncBatchNorm):
r"""Cross-GPU Synchronized Batch normalization (SyncBN)
- Standard BN [1]_ implementation only normalize the data within each device.
Standard BN [1]_ implementation only normalizes the data within each device (GPU).
SyncBN normalizes the input within the whole mini-batch.
We follow the sync-once implementation described in the paper [2]_ .
Please see the design idea in the `notes <./notes/syncbn.html>`_.
.. note::
- Please use ``CUDA_VISIBLE_DEVICES`` to select number of GPUs.
We adapt the awesome Python API from another `PyTorch SyncBN Implementation
<https://github.com/vacancy/Synchronized-BatchNorm-PyTorch>`_ and provide
an efficient CUDA backend.
.. math::
@@ -125,9 +170,9 @@ class BatchNorm2d(_SyncBatchNorm):
.. [2] Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, and Amit Agrawal. "Context Encoding for Semantic Segmentation." *CVPR 2018*
Examples:
- >>> # Use exactly the same as standard BatchNrom2d
>>> m = BatchNorm2d(100)
>>> net = torch.nn.DataParallel(m)
>>> encoding.parallel.patch_replication_callback(net)
>>> output = net(input)
"""
def _check_input_dim(self, input):
@@ -148,11 +193,12 @@ class BatchNorm3d(_SyncBatchNorm):
class SharedTensor(object):
"""Shared Tensor for cross GPU all reduce operation"""
- def __init__(self, nGPUs):
def __init__(self, nGPUs, op):
self.mutex = threading.Lock()
self.all_tasks_done = threading.Condition(self.mutex)
self.nGPUs = nGPUs
self._clear()
self.op = op
def _clear(self):
self.N = 0
@@ -160,7 +206,7 @@ class SharedTensor(object):
self.push_tasks = self.nGPUs
self.reduce_tasks = self.nGPUs
- def push(self, *inputs):
def __call__(self, *inputs):
if self.nGPUs <= 1:
return tuple(inputs)
# push from device
@@ -177,15 +223,13 @@ self.all_tasks_done.notify_all()
self.all_tasks_done.notify_all()
while self.push_tasks:
self.all_tasks_done.wait()
- def pull(self, igpu):
# pull from device
with self.mutex:
if igpu == 0:
assert(len(self.dict) == self.nGPUs)
# flatten the tensors
self.list = [t for i in range(len(self.dict)) for t in self.dict[i]]
- self.outlist = allreduce(2, *self.list)
self.outlist = self.op(2, *self.list)
self.reduce_tasks -= 1
else:
self.reduce_tasks -= 1
......
@@ -10,6 +10,7 @@
"""Encoding Data Parallel"""
import threading
import functools
import torch
from torch.autograd import Variable, Function
import torch.cuda.comm as comm
@@ -17,10 +18,11 @@ from torch.nn.parallel.data_parallel import DataParallel
from torch.nn.parallel.parallel_apply import get_a_var
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
- __all__ = ['allreduce', 'DataParallelModel', 'DataParallelCriterion']
torch_ver = torch.__version__[:3]
__all__ = ['allreduce', 'DataParallelModel', 'DataParallelCriterion',
'patch_replication_callback']
def allreduce(*inputs):
"""Cross GPU all reduce autograd operation for calculating mean and
variance in SyncBN.
@@ -94,6 +96,11 @@ class DataParallelModel(DataParallel):
def gather(self, outputs, output_device):
return outputs
def replicate(self, module, device_ids):
modules = super(DataParallelModel, self).replicate(module, device_ids)
execute_replication_callbacks(modules)
return modules
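Because DataParallelModel now runs the replication callbacks inside replicate(), wrapping a network that contains these BatchNorm layers should need no extra patching. A rough usage sketch (assumes a multi-GPU machine and a recent enough PyTorch; the toy network is an assumption):

import torch
import encoding

net = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3, padding=1),
    encoding.nn.BatchNorm2d(8),                    # statistics synchronized across replicas
).cuda()
model = encoding.parallel.DataParallelModel(net)   # replicate() triggers __data_parallel_replicate__
x = torch.randn(8, 3, 32, 32).cuda()
outputs = model(x)                                 # per-GPU outputs; gather() above is a no-op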
class DataParallelCriterion(DataParallel):
"""
@@ -181,3 +188,61 @@ def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices
raise output
outputs.append(output)
return outputs
###########################################################################
# Adapted from Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
class CallbackContext(object):
pass
def execute_replication_callbacks(modules):
"""
Execute a replication callback `__data_parallel_replicate__` on each module created
by the original replication.
The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`.
Note that, as all replicas are isomorphic, we assign each sub-module a context
(shared among the copies of this module on different devices).
Through this context, different copies can share some information.
We guarantee that the callback on the master copy (the first copy) will be called ahead
of the callback of any slave copy.
"""
master_copy = modules[0]
nr_modules = len(list(master_copy.modules()))
ctxs = [CallbackContext() for _ in range(nr_modules)]
for i, module in enumerate(modules):
for j, m in enumerate(module.modules()):
if hasattr(m, '__data_parallel_replicate__'):
m.__data_parallel_replicate__(ctxs[j], i)
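For illustration, any module can opt into the hook that execute_replication_callbacks looks for; a toy module (hypothetical, not part of the package) that records its copy id and shares state through the per-module context could look like:

import torch.nn as nn

class ReplicaAware(nn.Module):
    """Toy module showing the __data_parallel_replicate__ contract (illustrative only)."""
    def __data_parallel_replicate__(self, ctx, copy_id):
        # `ctx` is shared by all copies of this sub-module; copy 0 is called first (the master).
        self.copy_id = copy_id
        if copy_id == 0:
            ctx.master = self                      # master copy publishes itself for the slaves
        else:
            self.master_copy = getattr(ctx, 'master', None)

    def forward(self, x):
        return x                                   # identity; the hook above is the interesting part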
def patch_replication_callback(data_parallel):
"""
Monkey-patch an existing `DataParallel` object. Add the replication callback.
Useful when you have customized `DataParallel` implementation.
Examples:
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
> sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
> patch_replication_callback(sync_bn)
# this is equivalent to
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
> sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
"""
assert isinstance(data_parallel, DataParallel)
old_replicate = data_parallel.replicate
@functools.wraps(old_replicate)
def new_replicate(module, device_ids):
modules = old_replicate(module, device_ids)
execute_replication_callbacks(modules)
return modules
data_parallel.replicate = new_replicate