##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## ECE Department, Rutgers University
## Email: zhang.hang@rutgers.edu
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

"""Synchronized Cross-GPU Batch Normalization Module"""
import warnings
try:
    from queue import Queue
except ImportError:
    from Queue import Queue

import torch
from torch.nn.modules.batchnorm import _BatchNorm

from ..utils.misc import EncodingDeprecationWarning
from ..functions import *  # provides syncbatchnorm, inp_syncbatchnorm and dist_syncbatchnorm


__all__ = ['DistSyncBatchNorm', 'SyncBatchNorm', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']

class DistSyncBatchNorm(_BatchNorm):
    r"""Cross-GPU Synchronized Batch normalization (SyncBN)

    The standard BN [1]_ implementation only normalizes the data within each device (GPU).
    SyncBN normalizes the input within the whole mini-batch.
    We follow the sync-once implementation described in the paper [2]_ .
    Please see the design idea in the `notes <./notes/syncbn.html>`_.

    .. math::

        y = \frac{x - \mathrm{mean}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-channel over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        process_group: the ``torch.distributed`` process group to synchronize
            over. Default: ``None`` (use the default world group)

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Reference:
        .. [1] Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating deep network training by reducing internal covariate shift." *ICML 2015*
        .. [2] Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, and Amit Agrawal. "Context Encoding for Semantic Segmentation." *CVPR 2018*

    Examples:
        >>> m = DistSyncBatchNorm(100)
        >>> net = torch.nn.parallel.DistributedDataParallel(m)
        >>> output = net(input)
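
        A slightly fuller sketch (hypothetical; it assumes ``torch.distributed``
        has already been initialized, e.g. via ``init_process_group('nccl')``,
        and that ``local_rank`` holds the GPU index of the current process):

        >>> m = DistSyncBatchNorm(100).cuda(local_rank)
        >>> net = torch.nn.parallel.DistributedDataParallel(m, device_ids=[local_rank])
        >>> output = net(torch.randn(16, 100, 7, 7).cuda(local_rank))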
    """
    def __init__(self, num_features, eps=1e-5, momentum=0.1, process_group=None):
        super(DistSyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=True, track_running_stats=True)
        self.process_group = process_group

    def forward(self, x):
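        # Synchronization is only needed while the batch statistics are computed from
        # the current mini-batch (training, or eval without tracked running stats), and
        # only when more than one process participates in the group.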
        need_sync = self.training or not self.track_running_stats
        process_group = None
        if need_sync:
            process_group = torch.distributed.group.WORLD
            if self.process_group:
                process_group = self.process_group
            world_size = torch.distributed.get_world_size(process_group)
            need_sync = world_size > 1

        # Resize the input to (B, C, -1).
        input_shape = x.size()
        x = x.view(input_shape[0], self.num_features, -1)
        # dist_syncbatchnorm computes the statistics across the process group and is
        # expected to update running_mean / running_var in-place during training.
        y = dist_syncbatchnorm(x, self.weight, self.bias, self.running_mean, self.running_var,
                               self.eps, self.momentum, self.training, process_group)
        return y.view(input_shape)


class SyncBatchNorm(_BatchNorm):
    r"""Cross-GPU Synchronized Batch normalization (SyncBN)

    The standard BN [1]_ implementation only normalizes the data within each device (GPU).
    SyncBN normalizes the input within the whole mini-batch.
    We follow the sync-once implementation described in the paper [2]_ .
    Please see the design idea in the `notes <./notes/syncbn.html>`_.

    .. math::

        y = \frac{x - \mathrm{mean}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-channel over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        sync: a boolean value that when set to ``True``, synchronize across
            different gpus. Default: ``True``
        activation : str
            Name of the activation function, one of: `leaky_relu` or `none`.
        slope : float
            Negative slope for the `leaky_relu` activation.
        inplace : bool
            Whether to apply the activation in-place (Inplace ABN); only used
            when an activation is set. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)
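
    For example, in evaluation mode the layer falls back to the standard
    single-device BatchNorm and preserves the input shape (a minimal,
    illustrative sketch; the random input below is arbitrary):

        >>> layer = SyncBatchNorm(100).eval()
        >>> layer(torch.randn(4, 100, 8, 8)).shape
        torch.Size([4, 100, 8, 8])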

    Examples:
        >>> m = SyncBatchNorm(100)
        >>> net = torch.nn.DataParallel(m)
        >>> output = net(input)
        >>> # for Inplace ABN
        >>> from functools import partial
        >>> ABN = partial(SyncBatchNorm, activation='leaky_relu', slope=0.01, sync=True, inplace=True)
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, sync=True, activation="none", slope=0.01,
                 inplace=True):
        super(SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=True)
        self.activation = activation
        self.inplace = False if activation == 'none' else inplace
        self.slope = slope
        self.devices = list(range(torch.cuda.device_count()))
        self.sync = sync if len(self.devices) > 1 else False
        # Queues used to exchange per-device statistics between the master GPU and the workers
        self.worker_ids = self.devices[1:]
        self.master_queue = Queue(len(self.worker_ids))
        self.worker_queues = [Queue(1) for _ in self.worker_ids]

    def _check_input_dim(self, x):
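        # Skip the dimensionality check inherited from _BatchNorm so that 1d/2d/3d
        # inputs are all accepted (the training path flattens the input to (B, C, -1)).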
        pass

    def forward(self, x):
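        # In evaluation mode, fall back to standard (single-device) batch normalization
        # using the running statistics.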
        if not self.training:
            return super().forward(x)
        # Resize the input to (B, C, -1).
        input_shape = x.size()
        x = x.view(input_shape[0], self.num_features, -1)
        if x.get_device() == self.devices[0]:
            # Master mode
            extra = {
                "is_master": True,
                "master_queue": self.master_queue,
                "worker_queues": self.worker_queues,
                "worker_ids": self.worker_ids
            }
        else:
            # Worker mode
            extra = {
                "is_master": False,
                "master_queue": self.master_queue,
                "worker_queue": self.worker_queues[self.worker_ids.index(x.get_device())]
            }
        if self.inplace:
            return inp_syncbatchnorm(x, self.weight, self.bias, self.running_mean, self.running_var,
                                     extra, self.sync, self.training, self.momentum, self.eps,
                                     self.activation, self.slope).view(input_shape)
        else:
            return syncbatchnorm(x, self.weight, self.bias, self.running_mean, self.running_var,
                                 extra, self.sync, self.training, self.momentum, self.eps,
                                 self.activation, self.slope).view(input_shape)

    def extra_repr(self):
        if self.activation == 'none':
            return 'sync={}'.format(self.sync)
        else:
            return 'sync={}, act={}, slope={}, inplace={}'.format(
                self.sync, self.activation, self.slope, self.inplace
            )

class BatchNorm1d(SyncBatchNorm):
    r"""
    .. warning::
        BatchNorm1d is deprecated in favor of :class:`encoding.nn.SyncBatchNorm`.
    """
    def __init__(self, *args, **kwargs):
        warnings.warn("encoding.nn.{} is now deprecated in favor of encoding.nn.{}."
                      .format('BatchNorm1d', SyncBatchNorm.__name__), EncodingDeprecationWarning)
        super(BatchNorm1d, self).__init__(*args, **kwargs)

class BatchNorm2d(SyncBatchNorm):
    r"""
    .. warning::
        BatchNorm2d is deprecated in favor of :class:`encoding.nn.SyncBatchNorm`.
    """
    def __init__(self, *args, **kwargs):
        warnings.warn("encoding.nn.{} is now deprecated in favor of encoding.nn.{}."
                      .format('BatchNorm2d', SyncBatchNorm.__name__), EncodingDeprecationWarning)
        super(BatchNorm2d, self).__init__(*args, **kwargs)

class BatchNorm3d(SyncBatchNorm):
    r"""
    .. warning::
        BatchNorm3d is deprecated in favor of :class:`encoding.nn.SyncBatchNorm`.
    """
    def __init__(self, *args, **kwargs):
        warnings.warn("encoding.nn.{} is now deprecated in favor of encoding.nn.{}."
                      .format('BatchNorm3d', SyncBatchNorm.__name__), EncodingDeprecationWarning)
        super(BatchNorm3d, self).__init__(*args, **kwargs)
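

# A minimal, hypothetical smoke test; it is not part of the original module and only
# illustrates basic usage. The module path assumed for running it is
# ``python -m encoding.nn.syncbn`` (so that the relative imports above resolve).
if __name__ == '__main__':
    # The deprecated alias still behaves like SyncBatchNorm, but warns on construction.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        layer = BatchNorm2d(8)
    assert any(issubclass(w.category, EncodingDeprecationWarning) for w in caught)

    # In eval mode the layer runs the standard BatchNorm path, even on CPU.
    layer = layer.eval()
    out = layer(torch.randn(2, 8, 4, 4))
    print(out.shape)  # torch.Size([2, 8, 4, 4])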