syncbn.py 17.7 KB
Newer Older
Hang Zhang's avatar
v1.0.1  
Hang Zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## ECE Department, Rutgers University
## Email: zhang.hang@rutgers.edu
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree 
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

import math
import threading
import torch
import torch.cuda.comm as comm
from torch.autograd import Variable
from torch.nn import Module, Sequential
from torch.nn import functional as F
from torch.nn.parameter import Parameter
from torch.nn.modules.utils import _single, _pair, _triple
from torch.nn.parallel.scatter_gather import scatter, scatter_kwargs, \
    gather

from ..functions import view_each, multi_each, sum_each, batchnormtrain, batchnormeval, sum_square 
from ..parallel import my_data_parallel, Broadcast, AllReduce

__all__ = ['BatchNorm1d', 'BatchNorm2d']

class BatchNorm1d(Module):
    r"""Synchronized Batch Normalization 1d

    `Implementation ideas <./notes/syncbn.html>`_. Please use compatible
    :class:`encoding.parallel.SelfDataParallel` and :class:`encoding.nn`.

    Reference:

        Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang,
        Ambrish Tyagi, Amit Agrawal. "Context Encoding for Semantic
        Segmentation." *CVPR 2018*

    Applies Batch Normalization over a 3d input that is seen as a
    mini-batch.

    .. math::

        y = \frac{x - \mu[x]}{ \sqrt{var[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running estimates are updated with a default momentum
    of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    The input may also be a list/tuple of tensors (one chunk per GPU). When
    training with two or more chunks, the per-chunk sums and sums of squares
    are combined with ``AllReduce``, so every chunk is normalized with the
    statistics of the whole mini-batch. A list of outputs is then returned,
    in input order.

    Args:
        num_features: num_features from an expected input of size
            `batch_size x num_features x width`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to True, gives the layer
            learnable affine parameters. Default: True

    Shape:
        - Input: :math:`(N, C, L)`
        - Output: :math:`(N, C, L)` (same shape as input)

    Examples:
        >>> m = encoding.nn.BatchNorm1d(100).cuda()
        >>> input = autograd.Variable(torch.randn(20, 100, 40)).cuda()
        >>> output = m(input)
    """
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
        super(BatchNorm1d, self).__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps
        self.momentum = momentum
        if self.affine:
            self.weight = Parameter(torch.Tensor(num_features))
            self.bias = Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.reset_parameters()
        # guards the result dicts written by the worker threads
        self.writelock = threading.Lock()

    def reset_parameters(self):
        """Reset running stats to mean 0 / var 1 and re-init affine params."""
        self.running_mean.zero_()
        self.running_var.fill_(1)
        if self.affine:
            self.weight.data.uniform_()
            self.bias.data.zero_()

    def __repr__(self):
        return ('{name}({num_features}, eps={eps}, momentum={momentum},'
                ' affine={affine})'
                .format(name=self.__class__.__name__, **self.__dict__))

    def _check_input_dim(self, input):
        """Raise ValueError unless ``input`` is 3D ``(N, C, L)``."""
        if input.dim() != 3:
            raise ValueError('expected 3D input (got {}D input)'
                             .format(input.dim()))

    def forward(self, input):
        """Normalize a single tensor, or a list/tuple of per-GPU chunks.

        Raises:
            ValueError: on a non-3D input, an empty list, or (in training)
                at most one value per channel.
            RuntimeError: if ``input`` is neither a Variable nor a list/tuple.
        """
        if isinstance(input, Variable):
            self._check_input_dim(input)
            if not self.training:
                var_mean = Variable(self.running_mean, requires_grad=False)
                bias_var = Variable(self.running_var, requires_grad=False)
                std = (bias_var + self.eps).sqrt()
                return batchnormeval(
                    input, self.weight, self.bias, var_mean, std)
            # statistics are taken over the batch and length dimensions
            N = input.size(0) * input.size(2)
            if N <= 1:
                # the unbiased variance below divides by N - 1
                raise ValueError('expected more than 1 value per channel '
                                 'when training, got {}'.format(N))
            xsum, xsquare = sum_square(input.unsqueeze(3))
            mean = xsum / N
            sumvar = xsquare - xsum * xsum / N
            unbias_var = sumvar / (N - 1)
            # normalization uses the biased variance, as in standard BN
            std = (sumvar / N + self.eps).sqrt()
            # update running_mean and var
            self.running_mean = (1 - self.momentum) * self.running_mean \
                + self.momentum * mean.data
            self.running_var = (1 - self.momentum) * self.running_var \
                + self.momentum * unbias_var.data
            return batchnormtrain(input, self.weight, self.bias, mean, std)

        if isinstance(input, (tuple, list)):
            return self._forward_parallel(input)
        raise RuntimeError('unknown input type')

    def _forward_parallel(self, inputs):
        """Normalize a list/tuple of per-GPU chunks with shared statistics.

        In evaluation mode, the chunks are dispatched through
        ``my_data_parallel``. A single chunk is handled by :meth:`forward`.
        Otherwise:

        1. Each chunk's per-channel sum and sum of squares are computed on its
           own device.
        2. These are combined with ``AllReduce``.
        3. Every chunk is normalized with the global mean/std.

        Returns:
            In training with two or more chunks, a list of outputs, one per
            chunk and in input order.
        """
        if len(inputs) == 0:
            raise ValueError('expected at least one input chunk')
        for x in inputs:
            self._check_input_dim(x)
        # if evaluation, do it simple
        if not self.training:
            return my_data_parallel(self, inputs)
        if len(inputs) == 1:
            return self.forward(inputs[0])
        # Count values per channel chunk by chunk, so that an uneven split of
        # the batch across GPUs still yields the correct global statistics.
        N = sum(x.size(0) * x.size(2) for x in inputs)
        if N <= 1:
            raise ValueError('expected more than 1 value per channel '
                             'when training, got {}'.format(N))

        # 1) per-device sums and sums of squares
        def _stats(x):
            with torch.cuda.device_of(x):
                xsum, xsquare = sum_square(x.unsqueeze(3))
            return xsum, xsquare

        stats = self._run_in_threads(_stats, [(x,) for x in inputs])
        xsums = AllReduce()(*[s for s, _ in stats])
        xsquares = AllReduce()(*[sq for _, sq in stats])

        # 2) update running_mean and var from the reduced statistics
        xmean = xsums[0].data / N
        unbias_var = (xsquares[0].data - N * xmean * xmean) / (N - 1)
        self.running_mean = (1 - self.momentum) * self.running_mean \
            + self.momentum * xmean
        self.running_var = (1 - self.momentum) * self.running_var \
            + self.momentum * unbias_var

        # 3) broadcast weight/bias to the devices the chunks actually live on
        # NOTE(review): with affine=False, weight/bias are None; confirm
        # whether Broadcast/batchnormtrain support that.
        device_ids = [x.get_device() for x in inputs]
        weights = Broadcast(device_ids)(self.weight)
        biases = Broadcast(device_ids)(self.bias)

        def _normalize(x, xsum, xsquare, weight, bias):
            mean = xsum / N
            std = (xsquare / N - mean * mean + self.eps).sqrt()
            with torch.cuda.device_of(_get_a_var(x)):
                return batchnormtrain(x, weight, bias, mean, std)

        return self._run_in_threads(
            _normalize, list(zip(inputs, xsums, xsquares, weights, biases)))

    def _run_in_threads(self, fn, args_list):
        """Call ``fn(*args)`` for every tuple in ``args_list``, one thread each.

        Results are returned as a list in the order of ``args_list``. An
        exception raised by any call is captured in its worker and re-raised
        here (the first in order), so failures are never passed on silently.
        """
        results = {}

        def _worker(i, args):
            try:
                result = fn(*args)
            except Exception as e:
                result = e
            with self.writelock:
                results[i] = result

        threads = [threading.Thread(target=_worker, args=(i, args))
                   for i, args in enumerate(args_list)]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
        outputs = [results[i] for i in range(len(args_list))]
        for output in outputs:
            if isinstance(output, Exception):
                raise output
        return outputs


class BatchNorm2d(Module):
    r"""Synchronized Batch Normalization 2d

    `Implementation ideas <./notes/syncbn.html>`_. Please use compatible
    :class:`encoding.parallel.SelfDataParallel` and :class:`encoding.nn`.

    Reference:

        Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang,
        Ambrish Tyagi, Amit Agrawal. "Context Encoding for Semantic
        Segmentation." *CVPR 2018*

    Applies Batch Normalization over a 4d input that is seen as a mini-batch
    of 3d inputs.

    .. math::

        y = \frac{x - \mu[x]}{ \sqrt{var[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running estimates are updated with a default momentum
    of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    The input may also be a list/tuple of tensors (one chunk per GPU). When
    training with two or more chunks, the per-chunk sums and sums of squares
    are combined with ``AllReduce``, so every chunk is normalized with the
    statistics of the whole mini-batch. A list of outputs is then returned,
    in input order.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to True, gives the layer
            learnable affine parameters. Default: True

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples:
        >>> m = encoding.nn.BatchNorm2d(100).cuda()
        >>> input = autograd.Variable(torch.randn(20, 100, 35, 45)).cuda()
        >>> output = m(input)
    """
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
        super(BatchNorm2d, self).__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps
        self.momentum = momentum
        if self.affine:
            self.weight = Parameter(torch.Tensor(num_features))
            self.bias = Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.reset_parameters()
        # guards the result dicts written by the worker threads
        self.writelock = threading.Lock()

    def reset_parameters(self):
        """Reset running stats to mean 0 / var 1 and re-init affine params."""
        self.running_mean.zero_()
        self.running_var.fill_(1)
        if self.affine:
            self.weight.data.uniform_()
            self.bias.data.zero_()

    def __repr__(self):
        return ('{name}({num_features}, eps={eps}, momentum={momentum},'
                ' affine={affine})'
                .format(name=self.__class__.__name__, **self.__dict__))

    def _check_input_dim(self, input):
        """Raise ValueError unless ``input`` is 4D ``(N, C, H, W)``."""
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))

    def forward(self, input):
        """Normalize a single tensor, or a list/tuple of per-GPU chunks.

        Raises:
            ValueError: on a non-4D input, an empty list, or (in training)
                at most one value per channel.
            RuntimeError: if ``input`` is neither a Variable nor a list/tuple.
        """
        if isinstance(input, Variable):
            self._check_input_dim(input)
            B, C, H, W = input.size()
            if not self.training:
                var_mean = Variable(self.running_mean, requires_grad=False)
                bias_var = Variable(self.running_var, requires_grad=False)
                std = (bias_var + self.eps).sqrt()
                # make contiguous *before* view: view fails on strided input
                return batchnormeval(
                    input.contiguous().view(B, C, -1),
                    self.weight, self.bias, var_mean,
                    std).view(B, C, H, W)
            # statistics are taken over the batch and spatial dimensions
            N = B * H * W
            if N <= 1:
                # the unbiased variance below divides by N - 1
                raise ValueError('expected more than 1 value per channel '
                                 'when training, got {}'.format(N))
            xsum, xsquare = sum_square(input)
            mean = xsum / N
            sumvar = xsquare - xsum * xsum / N
            unbias_var = sumvar / (N - 1)
            # normalization uses the biased variance, as in standard BN
            std = (sumvar / N + self.eps).sqrt()
            # update running_mean and var
            self.running_mean = (1 - self.momentum) * self.running_mean \
                + self.momentum * mean.data
            self.running_var = (1 - self.momentum) * self.running_var \
                + self.momentum * unbias_var.data
            output = batchnormtrain(
                input.contiguous().view(B, C, -1), self.weight,
                self.bias, mean, std)
            return output.view(B, C, H, W)

        if isinstance(input, (tuple, list)):
            return self._forward_parallel(input)
        raise RuntimeError('unknown input type')

    def _forward_parallel(self, inputs):
        """Normalize a list/tuple of per-GPU chunks with shared statistics.

        In evaluation mode, the chunks are dispatched through
        ``my_data_parallel``. A single chunk is handled by :meth:`forward`.
        Otherwise:

        1. Each chunk's per-channel sum and sum of squares are computed on its
           own device.
        2. These are combined with ``AllReduce``.
        3. Every chunk is normalized with the global mean/std.

        Returns:
            In training with two or more chunks, a list of outputs, one per
            chunk and in input order.
        """
        if len(inputs) == 0:
            raise ValueError('expected at least one input chunk')
        for x in inputs:
            self._check_input_dim(x)
        # if evaluation, do it simple
        if not self.training:
            return my_data_parallel(self, inputs)
        if len(inputs) == 1:
            return self.forward(inputs[0])
        # Count values per channel chunk by chunk, so that an uneven split of
        # the batch across GPUs still yields the correct global statistics.
        N = sum(x.size(0) * x.size(2) * x.size(3) for x in inputs)
        if N <= 1:
            raise ValueError('expected more than 1 value per channel '
                             'when training, got {}'.format(N))

        # 1) per-device sums and sums of squares
        def _stats(x):
            with torch.cuda.device_of(x):
                xsum, xsquare = sum_square(x)
            return xsum, xsquare

        stats = self._run_in_threads(_stats, [(x,) for x in inputs])
        xsums = AllReduce()(*[s for s, _ in stats])
        xsquares = AllReduce()(*[sq for _, sq in stats])

        # 2) update running_mean and var from the reduced statistics
        xmean = xsums[0].data / N
        unbias_var = (xsquares[0].data - N * xmean * xmean) / (N - 1)
        self.running_mean = (1 - self.momentum) * self.running_mean \
            + self.momentum * xmean
        self.running_var = (1 - self.momentum) * self.running_var \
            + self.momentum * unbias_var

        # 3) broadcast weight/bias to the devices the chunks actually live on
        # NOTE(review): with affine=False, weight/bias are None; confirm
        # whether Broadcast/batchnormtrain support that.
        device_ids = [x.get_device() for x in inputs]
        weights = Broadcast(device_ids)(self.weight)
        biases = Broadcast(device_ids)(self.bias)

        def _normalize(x, xsum, xsquare, weight, bias):
            mean = xsum / N
            std = (xsquare / N - mean * mean + self.eps).sqrt()
            with torch.cuda.device_of(_get_a_var(x)):
                B, C, H, W = x.size()
                return batchnormtrain(
                    x.contiguous().view(B, C, -1), weight, bias, mean,
                    std).view(B, C, H, W)

        return self._run_in_threads(
            _normalize, list(zip(inputs, xsums, xsquares, weights, biases)))

    def _run_in_threads(self, fn, args_list):
        """Call ``fn(*args)`` for every tuple in ``args_list``, one thread each.

        Results are returned as a list in the order of ``args_list``. An
        exception raised by any call is captured in its worker and re-raised
        here (the first in order), so failures are never passed on silently.
        """
        results = {}

        def _worker(i, args):
            try:
                result = fn(*args)
            except Exception as e:
                result = e
            with self.writelock:
                results[i] = result

        threads = [threading.Thread(target=_worker, args=(i, args))
                   for i, args in enumerate(args_list)]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
        outputs = [results[i] for i in range(len(args_list))]
        for output in outputs:
            if isinstance(output, Exception):
                raise output
        return outputs
Hang Zhang's avatar
v0.1.0  
Hang Zhang committed
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440


def _get_a_var(obj):
    """Return the first Variable found in ``obj``, or None if there is none.

    ``obj`` may itself be a Variable, or a (possibly nested) list, tuple or
    dict containing one. Lists and tuples are searched in order; dicts are
    searched over their ``(key, value)`` item pairs in iteration order.
    """
    if isinstance(obj, Variable):
        return obj
    if isinstance(obj, (list, tuple)):
        candidates = obj
    elif isinstance(obj, dict):
        candidates = obj.items()
    else:
        return None
    for candidate in candidates:
        found = _get_a_var(candidate)
        if isinstance(found, Variable):
            return found
    return None