##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## ECE Department, Rutgers University
## Email: zhang.hang@rutgers.edu
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

"""Synchronized Cross-GPU Batch Normalization Module"""
import threading
import torch
from torch.nn import Module, Sequential, Conv1d, Conv2d, ConvTranspose2d, \
    ReLU, Sigmoid, MaxPool2d, AvgPool2d, AdaptiveAvgPool2d, Dropout2d, Linear
from torch.nn.modules.batchnorm import _BatchNorm

from ..functions import batchnormtrain, batchnormeval, sum_square
from ..parallel import allreduce

__all__ = ['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'Module', 'Sequential', 'Conv1d',
           'Conv2d', 'ConvTranspose2d', 'ReLU', 'Sigmoid', 'MaxPool2d', 'AvgPool2d',
           'AdaptiveAvgPool2d', 'Dropout2d', 'Linear']

class _SyncBatchNorm(_BatchNorm):
    # pylint: disable=access-member-before-definition
    def __init__(self, num_features, eps=1e-5, momentum=0.1, **kwargs):
        super(_SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, **kwargs)
        # syncBN
        self.writelock = threading.Lock()
        nGPUs = torch.cuda.device_count()
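        # one rendezvous buffer per layer: replicate() shallow-copies module
        # attributes, so every device replica shares this same SharedTensor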
        self.sharedT = SharedTensor(nGPUs)

    def forward(self, input):
        self._check_input_dim(input)
        input_shape = input.size()
        input = input.view(input_shape[0], self.num_features, -1)
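        # evaluation mode: normalize with the running statistics;
        # no cross-GPU synchronization is needed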
        if not self.training:
            std = (self.running_var.clamp(self.eps)).sqrt()
            output = batchnormeval(input, self.weight, self.bias, self.running_mean, std)
            return output.view(input_shape)
        # get global sum(x) and sum(x^2)
        xsum, xsquare = self.sharedT(sum_square(input.unsqueeze(3)))
        # calculate mean, var
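        # N is the per-channel element count over the *global* batch:
        # (num GPUs) x (per-GPU batch size) x (spatial positions)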
        N = len(self.sharedT) * input.size(0) * input.size(2)
        mean = xsum / N
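        # sum of squared deviations via the identity
        # sum((x - mean)^2) = sum(x^2) - (sum(x))^2 / N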
        sumvar = xsquare - xsum * xsum / N
        unbias_var = sumvar / (N - 1)
        bias_var = sumvar / N
        std = bias_var.clamp(self.eps).sqrt()
        # update running_mean and var
        self.running_mean = (1-self.momentum) * self.running_mean + self.momentum * mean.data
        self.running_var = (1-self.momentum) * self.running_var + self.momentum * unbias_var.data
        # forward
        return batchnormtrain(input, self.weight, self.bias, mean, std).view(input_shape)


class BatchNorm1d(_SyncBatchNorm):
    r"""Please see the docs in :class:`encoding.nn.BatchNorm2d`"""
    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))

class BatchNorm2d(_SyncBatchNorm):
    r"""Cross-GPU Synchronized Batch normalization (SyncBN)

    Standard BN [1]_ implementations only normalize the data within each device.
    SyncBN normalizes the input within the whole mini-batch.
    We follow the sync-once implementation described in the paper [2]_ .
    Please see the design idea in the `notes <./notes/syncbn.html>`_.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    The mean and standard-deviation are calculated per-channel over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Reference:
        .. [1] Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating deep network training by reducing internal covariate shift." *ICML 2015*
        .. [2] Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, and Amit Agrawal. "Context Encoding for Semantic Segmentation." *CVPR 2018*

    Examples:
        >>> # Use exactly the same as standard BatchNorm2d
        >>> m = nn.BatchNorm2d(100)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
        >>> output = m(input)
    """
    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))

class BatchNorm3d(_SyncBatchNorm):
    r"""Please see the docs in :class:`encoding.nn.BatchNorm2d`"""
    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))

class SharedTensor(object):
    """Shared Tensor for cross GPU communication
    """
    def __init__(self, nGPUs):
        self.mutex = threading.Lock()
        self.all_tasks_done = threading.Condition(self.mutex)
        self.nGPUs = nGPUs
        self._clear()

    def _clear(self):
        self.list = []
        self.push_tasks = self.nGPUs
        self.reduce_tasks = self.nGPUs

    def __call__(self, *inputs):
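        # two-phase rendezvous: every replica thread deposits its
        # (xsum, xsquare) pair, a single thread reduces all pairs, and each
        # thread reads back the reduced result for its own slot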
        # push from device
        with self.mutex:
            if self.push_tasks == 0:
                self._clear()
            self.list.extend(list(*inputs))
            idx = self.nGPUs - self.push_tasks
            self.push_tasks -= 1
        with self.all_tasks_done:
            if self.push_tasks == 0:
                self.all_tasks_done.notify_all()
            while self.push_tasks:
                self.all_tasks_done.wait()
        # pull from device
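        # only the first thread to arrive performs the allreduce;
        # the rest decrement the counter and wait at the barrier below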
        with self.mutex:
            if self.reduce_tasks == self.nGPUs:
                assert len(self.list) == 2 * self.nGPUs
                self.list = allreduce(2, *self.list)
            self.reduce_tasks -= 1
        with self.all_tasks_done:
            if self.reduce_tasks == 0:
                self.all_tasks_done.notify_all()
            while self.reduce_tasks:
                self.all_tasks_done.wait()
        # all reduce done
        return self.list[2*idx], self.list[2*idx+1]

    def __len__(self):
        return self.nGPUs

    def __repr__(self):
        return 'SharedTensor'
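

# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustration added here, not part of the original
# module). Assumptions: the `encoding` package is built with its CUDA
# extensions, at least one GPU is visible, the batch splits across every
# visible device, and the file is executed within the package (e.g.
# ``python -m encoding.nn.syncbn``) so the relative imports resolve.
# SharedTensor blocks until torch.cuda.device_count() threads check in, so
# the layer must be driven with one replica per device; torch.nn.DataParallel
# does this, and because replicate() shallow-copies plain attributes, all
# replicas share a single SharedTensor.
if __name__ == '__main__':
    net = torch.nn.DataParallel(
        Sequential(Conv2d(3, 8, kernel_size=3, padding=1),
                   BatchNorm2d(8))).cuda()
    x = torch.randn(8, 3, 16, 16).cuda()
    y = net(x)                        # training path: synchronized statistics
    assert y.size() == (8, 8, 16, 16)
    net.eval()
    y = net(x)                        # eval path: running mean/var, no sync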