syncbn.py
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## Email: zhanghang0704@gmail.com
## Copyright (c) 2018
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

"""Synchronized Cross-GPU Batch Normalization functions"""
import torch
from torch.autograd import Function
from .. import lib

__all__ = ['sum_square', 'batchnormtrain']

def sum_square(input):
    r"""Calculate sum of elements and sum of squares for Batch Normalization"""
    return _sum_square.apply(input)
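# A CPU reference for what the CUDA kernel is expected to return (a sketch,
# not the library's implementation; it assumes the usual BN layout with
# channels on dim 1 and the reduction running over batch and spatial dims):
#
#     x = input.transpose(0, 1).contiguous().view(input.size(1), -1)
#     xsum = x.sum(1)            # per-channel sum of elements, shape (C,)
#     xsqusum = (x ** 2).sum(1)  # per-channel sum of squares, shape (C,)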


class _sum_square(Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        if input.is_cuda:
            xsum, xsqusum = lib.gpu.sumsquare_forward(input)
        else:
            raise NotImplementedError
        return xsum, xsqusum

    @staticmethod
    def backward(ctx, gradSum, gradSquare):
        input, = ctx.saved_tensors
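        # Gradient sketch: with s = sum(x) and q = sum(x ** 2) per channel,
        # the chain rule gives dL/dx = gradSum + 2 * x * gradSquare, the
        # per-channel gradients broadcast over batch and spatial positions;
        # the CUDA call below is expected to compute this elementwise.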
        if input.is_cuda:
            gradInput = lib.gpu.sumsquare_backward(input, gradSum, gradSquare)
        else:
            raise NotImplementedError
        return gradInput


class _batchnormtrain(Function):
    @staticmethod
    def forward(ctx, input, mean, std, gamma, beta):
        ctx.save_for_backward(input, mean, std, gamma, beta)
        if input.is_cuda:
            output = lib.gpu.batchnorm_forward(input, mean, std, gamma, beta)
        else:
            raise NotImplementedError
        return output

    @staticmethod
    def backward(ctx, gradOutput):
        input, mean, std, gamma, beta = ctx.saved_tensors
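        # Gradient sketch for y = gamma * (x - mean) / std + beta, with mean
        # and std treated as independent inputs (their own backward flows
        # through the sum_square graph). Per channel, with g = gradOutput:
        #   gradInput = g * gamma / std
        #   gradMean  = -sum(g) * gamma / std
        #   gradStd   = -sum(g * (x - mean)) * gamma / std ** 2
        #   gradGamma = sum(g * (x - mean) / std)
        #   gradBeta  = sum(g)
        # The CUDA kernel below is expected to fuse these reductions.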
        if gradOutput.is_cuda:
            gradInput, gradMean, gradStd, gradGamma, gradBeta = \
                lib.gpu.batchnorm_backward(gradOutput, input, mean,
                                           std, gamma, beta, True)
        else:
            raise NotImplementedError
        return gradInput, gradMean, gradStd, gradGamma, gradBeta


def batchnormtrain(input, mean, std, gamma, beta):
    r"""Applies Batch Normalization over a 3d input that is seen as a
    mini-batch.

    .. _encoding.batchnormtrain:

    .. math::

        y = \frac{x - \mu[x]}{ \sqrt{var[x] + \epsilon}} * \gamma + \beta

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    """
    return _batchnormtrain.apply(input, mean, std, gamma, beta)
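

# Minimal usage sketch (illustrative; the eps handling and the cross-GPU
# all-reduce of xsum/xsqusum that makes this "synchronized" are omitted):
#
#     x = torch.randn(8, 4, 16).cuda()      # (N, C, L)
#     xsum, xsqusum = sum_square(x)
#     n = x.numel() / x.size(1)             # elements per channel
#     mean = xsum / n
#     std = (xsqusum / n - mean ** 2).clamp(1e-5).sqrt()
#     gamma = torch.ones(4).cuda()
#     beta = torch.zeros(4).cuda()
#     y = batchnormtrain(x, mean, std, gamma, beta)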