"git@developer.sourcefind.cn:change/sglang.git" did not exist on "a23c30205d18a7953f63930b95686b50438f8736"
syncbn.py 9.54 KB
Newer Older
Hang Zhang's avatar
v1.0.1  
Hang Zhang committed
1
2
3
4
5
6
7
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## ECE Department, Rutgers University
## Email: zhang.hang@rutgers.edu
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

"""Synchronized Cross-GPU Batch Normalization Module"""
import collections
import threading
import torch
from torch.nn import Module, Sequential, Conv1d, Conv2d, ConvTranspose2d, \
    ReLU, Sigmoid, MaxPool2d, AvgPool2d, AdaptiveAvgPool2d, Dropout2d, Linear, \
    DataParallel
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.functional import batch_norm
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast

from ..functions import *
from ..parallel import allreduce
from .comm import SyncMaster


__all__ = ['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'Module', 'Sequential', 'Conv1d',
           'Conv2d', 'ConvTranspose2d', 'ReLU', 'Sigmoid', 'MaxPool2d', 'AvgPool2d',
           'AdaptiveAvgPool2d', 'Dropout2d', 'Linear']

class _SyncBatchNorm(_BatchNorm):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
        super(_SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)

        self._sync_master = SyncMaster(self._data_parallel_master)
        self._parallel_id = None
        self._slave_pipe = None

    def forward(self, input):
        if not self.training:
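            # Eval mode: fall back to the standard per-device batch_norm using the running statistics.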
            return batch_norm(
                input, self.running_mean, self.running_var, self.weight, self.bias,
                self.training, self.momentum, self.eps)

        # Resize the input to (B, C, -1).
        input_shape = input.size()
        input = input.view(input_shape[0], self.num_features, -1)

        # sum(x) and sum(x^2)
        N = input.size(0) * input.size(2)
        xsum, xsqsum = sum_square(input)

        # all-reduce for global sum(x) and sum(x^2)
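        # The master copy (parallel_id == 0) collects the per-device statistics through SyncMaster,
        # computes the global mean / inv_std, and broadcasts them back; the other copies block on
        # their slave pipes until the reduced statistics arrive.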
        if self._parallel_id == 0:
            mean, inv_std = self._sync_master.run_master(_ChildMessage(xsum, xsqsum, N))
        else:
            mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(xsum, xsqsum, N))
        # forward
        return batchnormtrain(input, mean, 1.0/inv_std, self.weight, self.bias).view(input_shape)

    def __data_parallel_replicate__(self, ctx, copy_id):
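        # Called on each module copy by the data-parallel replication callback
        # (see encoding.parallel.patch_replication_callback); copy_id identifies the replica.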
        self._parallel_id = copy_id

        # parallel_id == 0 means master device.
        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)

    def _data_parallel_master(self, intermediates):
        """Reduce the sum and square-sum, compute the statistics, and broadcast it."""

        # Always using same "device order" makes the ReduceAdd operation faster.
        # Thanks to:: Tete Xiao (http://tetexiao.com/)
        intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())

        to_reduce = [i[1][:2] for i in intermediates]
        to_reduce = [j for i in to_reduce for j in i]  # flatten
        target_gpus = [i[1].sum.get_device() for i in intermediates]

        sum_size = sum([i[1].sum_size for i in intermediates])
        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)

        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)

        outputs = []
        for i, rec in enumerate(intermediates):
            outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))

        return outputs

    def _compute_mean_std(self, sum_, ssum, size):
        """Compute the mean and standard-deviation with sum and square-sum. This method
        also maintains the moving average on the master device."""
        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
        mean = sum_ / size
        sumvar = ssum - sum_ * mean
        unbias_var = sumvar / (size - 1)
        bias_var = sumvar / size

        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data

        return mean, (bias_var + self.eps) ** -0.5


# API adapted from https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
_ChildMessage = collections.namedtuple('Message', ['sum', 'ssum', 'sum_size'])
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])


class BatchNorm1d(_SyncBatchNorm):
    r"""Please see the docs in :class:`encoding.nn.BatchNorm2d`"""
    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))
        super(BatchNorm1d, self)._check_input_dim(input)


class BatchNorm2d(_SyncBatchNorm):
    r"""Cross-GPU Synchronized Batch normalization (SyncBN)

    Standard BN [1]_ implementation only normalizes the data within each device (GPU).
    SyncBN normalizes the input within the whole mini-batch.
    We follow the sync-once implementation described in the paper [2]_ .
    Please see the design idea in the `notes <./notes/syncbn.html>`_.

    .. note::
        We adapt the awesome python API from another `PyTorch SyncBN Implementation
        <https://github.com/vacancy/Synchronized-BatchNorm-PyTorch>`_ and provide
        an efficient CUDA backend.

    .. math::

        y = \frac{x - \mathrm{mean}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-channel over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running estimates are updated with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Reference:
        .. [1] Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating deep network training by reducing internal covariate shift." *ICML 2015*
        .. [2] Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, and Amit Agrawal. "Context Encoding for Semantic Segmentation." *CVPR 2018*

    Examples:
        >>> m = BatchNorm2d(100)
        >>> net = torch.nn.DataParallel(m)
        >>> encoding.parallel.patch_replication_callback(net)
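        >>> # hypothetical input; any 4D (N, C, H, W) CUDA tensor with C == 100 works
        >>> input = torch.rand(16, 100, 32, 32).cuda()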
        >>> output = net(input)
    """
    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))
        super(BatchNorm2d, self)._check_input_dim(input)


class BatchNorm3d(_SyncBatchNorm):
    r"""Please see the docs in :class:`encoding.nn.BatchNorm2d`"""
    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))
        super(BatchNorm3d, self)._check_input_dim(input)


class SharedTensor(object):
    """Shared Tensor for cross GPU all reduce operation"""
    def __init__(self, nGPUs):
        self.mutex = threading.Lock()
        self.all_tasks_done = threading.Condition(self.mutex)
        self.nGPUs = nGPUs
        self._clear()

    def _clear(self):
        self.N = 0
        self.dict = {}
        self.push_tasks = self.nGPUs
        self.reduce_tasks = self.nGPUs

    def push(self, *inputs):
        # push from device
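        # inputs layout: (element count, device index, *tensors to reduce),
        # e.g. (N, igpu, xsum, xsqsum) for batch-norm statistics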
        with self.mutex:
            if self.push_tasks == 0:
                self._clear()
            self.N += inputs[0]
            igpu = inputs[1]
            self.dict[igpu] = inputs[2:]
            #idx = self.nGPUs - self.push_tasks
            self.push_tasks -= 1
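        # Barrier: wait until every device has pushed its tensors for this round.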
        with self.all_tasks_done:
            if self.push_tasks == 0:
                self.all_tasks_done.notify_all()
            while self.push_tasks:
                self.all_tasks_done.wait()

    def pull(self, igpu):
        # pull from device
        with self.mutex:
            if igpu == 0:
                assert(len(self.dict) == self.nGPUs)
                # flatten the tensors
                self.list = [t for i in range(len(self.dict)) for t in self.dict[i]]
                self.outlist = allreduce(2, *self.list)
                self.reduce_tasks -= 1
            else:
                self.reduce_tasks -= 1
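        # Barrier: wait for the all-reduce (performed once, by device 0) to complete before reading results.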
        with self.all_tasks_done:
            if self.reduce_tasks == 0:
                self.all_tasks_done.notify_all()
            while self.reduce_tasks:
                self.all_tasks_done.wait()
        # all reduce done
        return self.N, self.outlist[2*igpu], self.outlist[2*igpu+1]

    def __len__(self):
        return self.nGPUs

    def __repr__(self):
        return 'SharedTensor'