##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## ECE Department, Rutgers University
## Email: zhang.hang@rutgers.edu
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

"""Synchronized Cross-GPU Batch Normalization functions"""
import torch
from torch.autograd import Variable, Function
from .._ext import encoding_lib

__all__ = ['sum_square', 'batchnormtrain', 'batchnormeval']

def sum_square(input):
    r"""Calculate sum of elements and sum of squares for Batch Normalization"""
    return _sum_square.apply(input)
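
# Illustrative pure-PyTorch sketch of what the CUDA kernel behind
# ``sum_square`` computes for a 3-d input of shape (N, C, L): per-channel
# sums of the elements and of their squares.  ``_sum_square_reference`` is a
# hypothetical helper added for documentation only; it is not used by the
# extension and is not part of the public API.
def _sum_square_reference(x):
    # flatten everything except the channel dimension: (N, C, L) -> (C, N*L)
    flat = x.transpose(0, 1).contiguous().view(x.size(1), -1)
    xsum = flat.sum(1).view(-1)            # per-channel sum of elements, shape (C,)
    xsquare = flat.pow(2).sum(1).view(-1)  # per-channel sum of squares, shape (C,)
    return xsum, xsquare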


class _sum_square(Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        C = input.size(1)
        with torch.cuda.device_of(input):
            xsum = input.new().resize_(C).zero_()
            xsquare = input.new().resize_(C).zero_()
        if isinstance(input, torch.cuda.FloatTensor):
            with torch.cuda.device_of(input):
                encoding_lib.Encoding_Float_sum_square_Forward(
                    input, xsum, xsquare)
        elif isinstance(input, torch.cuda.DoubleTensor):
            with torch.cuda.device_of(input):
                encoding_lib.Encoding_Double_sum_square_Forward(
                    input, xsum, xsquare)
        else:
            raise RuntimeError('Unimplemented data type!', type(input))
        return xsum, xsquare

    @staticmethod
    def backward(ctx, gradSum, gradSquare):
        # Chain rule: d(xsum)/dx = 1 and d(xsquare)/dx = 2x, so the backward
        # pass produces gradInput = gradSum + 2 * input * gradSquare
        # (per-channel gradients broadcast over the N and L dimensions).
        input, = ctx.saved_variables
        with torch.cuda.device_of(input.data):
            gradInput = Variable(input.data.new().resize_as_(input.data).zero_())
        if isinstance(input.data, torch.cuda.FloatTensor):
            with torch.cuda.device_of(input.data):
                encoding_lib.Encoding_Float_sum_square_Backward(
                    gradInput.data, input.data, gradSum.data, gradSquare.data)
        elif isinstance(input.data, torch.cuda.DoubleTensor):
            with torch.cuda.device_of(input.data):
                encoding_lib.Encoding_Double_sum_square_Backward(
                    gradInput.data, input.data, gradSum.data, gradSquare.data)
        else:
            raise RuntimeError('Unimplemented data type!')
        return gradInput


class _batchnorm(Function):
    # Legacy (non-static) autograd Function: the constructor stores the
    # ``training`` flag, which is later passed to the backward CUDA kernel.
    def __init__(self, training=False):
        super(_batchnorm, self).__init__()
        self.training = training

    def forward(self, input, gamma, beta, mean, std):
        self.save_for_backward(input, gamma, beta, mean, std)
        assert(input.dim() == 3)
        with torch.cuda.device_of(input):
            invstd = 1.0 / std
            output = input.new().resize_as_(input)
        if isinstance(input, torch.cuda.FloatTensor):
            with torch.cuda.device_of(input):
                encoding_lib.Encoding_Float_batchnorm_Forward(output,
                    input, mean, invstd, gamma, beta)
        elif isinstance(input, torch.cuda.DoubleTensor):
            with torch.cuda.device_of(input):
                encoding_lib.Encoding_Double_batchnorm_Forward(output,
                    input, mean, invstd, gamma, beta)
        else:
            raise RuntimeError('Unimplemented data type!')
        return output

    def backward(self, gradOutput):
        input, gamma, beta, mean, std = self.saved_tensors
        invstd = 1.0 / std
        with torch.cuda.device_of(input):
            gradInput = gradOutput.new().resize_as_(input).zero_()
            gradGamma = gradOutput.new().resize_as_(gamma).zero_()
            gradBeta = gradOutput.new().resize_as_(beta).zero_()
            gradMean = gradOutput.new().resize_as_(mean).zero_()
            gradStd = gradOutput.new().resize_as_(std).zero_()

        if isinstance(input, torch.cuda.FloatTensor):
            with torch.cuda.device_of(input):
                encoding_lib.Encoding_Float_batchnorm_Backward(
                    gradOutput, input, gradInput, gradGamma, gradBeta,
                    mean, invstd, gamma, beta, gradMean, gradStd,
                    self.training)
        elif isinstance(input, torch.cuda.DoubleTensor):
            with torch.cuda.device_of(input):
                encoding_lib.Encoding_Double_batchnorm_Backward(
                    gradOutput, input, gradInput, gradGamma, gradBeta,
                    mean, invstd, gamma, beta, gradMean, gradStd,
                    self.training)
        else:
            raise RuntimeError('Unimplemented data type!')
        return gradInput, gradGamma, gradBeta, gradMean, gradStd


def batchnormtrain(input, gamma, beta, mean, std):
    r"""Applies Batch Normalization over a 3d input that is seen as a
    mini-batch.

    .. _encoding.batchnormtrain:

    .. math::

        y = \frac{x - \mu[x]}{ \sqrt{var[x] + \epsilon}} * \gamma + \beta

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    """
    return _batchnorm(True)(input, gamma, beta, mean, std)
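

# Illustrative pure-PyTorch equivalent of the training-mode forward above,
# assuming ``input`` has shape (N, C, L) and ``mean``/``std`` are per-channel
# statistics of shape (C,) with eps already folded into ``std``.  This
# ``_batchnormtrain_reference`` helper is a hypothetical sketch added for
# documentation; the real computation runs in the CUDA kernels of encoding_lib.
def _batchnormtrain_reference(input, gamma, beta, mean, std):
    # broadcast the per-channel statistics and parameters over N and L
    x_hat = (input - mean.view(1, -1, 1)) / std.view(1, -1, 1)
    return x_hat * gamma.view(1, -1, 1) + beta.view(1, -1, 1)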


def batchnormeval(input, gamma, beta, mean, std):
    r"""Applies Batch Normalization over a 3d input that is seen as a
    mini-batch.

    Please see encoding.batchnormtrain_
    """
    return _batchnorm(False)(input, gamma, beta, mean, std)
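

# Sketch of how these primitives are typically combined.  The cross-GPU
# reduction itself is handled outside this file; the single-GPU helper below
# (``_syncbn_sketch``, a hypothetical name used only for illustration) shows
# how ``sum_square`` and ``batchnormtrain`` fit together.
def _syncbn_sketch(x, gamma, beta, eps=1e-5):
    # per-channel partial sums; in the synchronized setting these (and the
    # element count) would be summed across GPUs before computing statistics
    xsum, xsquare = sum_square(x)
    n = x.numel() / x.size(1)            # number of elements per channel
    mean = xsum / n                      # per-channel mean
    var = xsquare / n - mean * mean      # E[x^2] - (E[x])^2
    std = (var + eps).sqrt()             # eps folded into std, as expected above
    return batchnormtrain(x, gamma, beta, mean, std)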