# Copyright 2016-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import sparseconvnet.SCN
from torch.autograd import Function
from torch.nn import Module, Parameter
from .utils import *
from .sparseConvNetTensor import SparseConvNetTensor


class BatchNormalization(Module):
    """
    Batch normalization over the feature planes of a SparseConvNetTensor,
    with an optional fused (leaky) ReLU activation.

    Parameters:
    nPlanes : number of input feature planes
    eps : small number used to stabilise standard deviation calculation
        (default 1e-4)
    momentum : weight of the old value when updating the running statistics
        used at test time (default 0.9)
    affine : only True is supported at present (default True)
    leakiness : fused activation applied after normalization:
        0 for ReLU, values in (0,1) for LeakyReLU, 1 for no activation
        (default 1).
    """

    def __init__(self,
                 nPlanes,
                 eps=1e-4,
                 momentum=0.9,
                 affine=True,
                 leakiness=1):
        Module.__init__(self)
        self.nPlanes = nPlanes
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.leakiness = leakiness
        # Running statistics used at evaluation time; updated by the C/CUDA
        # kernel during training.
        self.register_buffer("runningMean", torch.Tensor(nPlanes).fill_(0))
        self.register_buffer("runningVar", torch.Tensor(nPlanes).fill_(1))
        if affine:
            # Learnable per-plane scale and shift.
            self.weight = Parameter(torch.Tensor(nPlanes).fill_(1))
            self.bias = Parameter(torch.Tensor(nPlanes).fill_(0))

    def forward(self, input):
        # Empty batches (no active sites) are allowed to pass through; the
        # plane count is only checked when there are features to normalize.
        assert input.features.nelement() == 0 or input.features.size(1) == self.nPlanes
        output = SparseConvNetTensor()
        # Normalization acts on features only; the sparsity pattern
        # (metadata, spatial size) is passed through unchanged.
        output.metadata = input.metadata
        output.spatial_size = input.spatial_size
        output.features = BatchNormalizationFunction.apply(
            input.features,
            optionalTensor(self, 'weight'),
            optionalTensor(self, 'bias'),
            self.runningMean,
            self.runningVar,
            self.eps,
            self.momentum,
            self.training,
            self.leakiness)
        return output

    def input_spatial_size(self, out_size):
        # Pointwise operation: the spatial size is unchanged.
        return out_size

    def __repr__(self):
        s = 'BatchNorm(' + str(self.nPlanes) + ',eps=' + str(self.eps) + \
            ',momentum=' + str(self.momentum) + ',affine=' + str(self.affine)
        if self.leakiness > 0:
            s = s + ',leakiness=' + str(self.leakiness)
        s = s + ')'
        return s


class BatchNormReLU(BatchNormalization):
    """BatchNormalization with a fused ReLU activation (leakiness=0)."""

    def __init__(self, nPlanes, eps=1e-4, momentum=0.9):
        # affine=True, leakiness=0 selects the ReLU variant of the kernel.
        BatchNormalization.__init__(self, nPlanes, eps, momentum, True, 0)

    def __repr__(self):
        return ('BatchNormReLU(' + str(self.nPlanes) +
                ',eps=' + str(self.eps) +
                ',momentum=' + str(self.momentum) +
                ',affine=' + str(self.affine) + ')')


class BatchNormLeakyReLU(BatchNormalization):
    """BatchNormalization with a fused LeakyReLU activation.

    Parameters:
    nPlanes : number of input feature planes
    eps : stabiliser for the standard deviation (default 1e-4)
    momentum : running-statistics momentum (default 0.9)
    leakiness : negative-slope of the fused LeakyReLU (default 0.333);
        new optional parameter, defaulting to the previous hard-coded value.
    """

    def __init__(self, nPlanes, eps=1e-4, momentum=0.9, leakiness=0.333):
        BatchNormalization.__init__(self, nPlanes, eps, momentum, True, leakiness)

    def __repr__(self):
        # Bug fix: previously reported itself as 'BatchNormReLU' and
        # omitted the leakiness value.
        s = 'BatchNormLeakyReLU(' + str(self.nPlanes) + ',eps=' + str(self.eps) + \
            ',momentum=' + str(self.momentum) + ',affine=' + str(self.affine) + \
            ',leakiness=' + str(self.leakiness) + ')'
        return s
class BatchNormalizationFunction(Function):
    """
    Autograd wrapper around the SCN batch-normalization kernels.

    forward normalizes the feature matrix in the C/CUDA extension and
    records the per-batch statistics; backward is only valid in training
    mode, where those saved statistics are available.
    """

    @staticmethod
    def forward(ctx,
                input_features,
                weight,
                bias,
                runningMean,
                runningVar,
                eps,
                momentum,
                train,
                leakiness):
        n_planes = runningMean.shape[0]
        ctx.nPlanes = n_planes
        ctx.train = train
        ctx.leakiness = leakiness
        out_features = input_features.new()
        # Per-batch statistics, filled in by the kernel and saved for backward.
        batch_mean = input_features.new().resize_(n_planes)
        batch_inv_std = runningMean.clone().resize_(n_planes)
        sparseconvnet.SCN.BatchNormalization_updateOutput(
            input_features,
            out_features,
            batch_mean,
            batch_inv_std,
            runningMean,
            runningVar,
            weight,
            bias,
            eps,
            momentum,
            ctx.train,
            ctx.leakiness)
        ctx.save_for_backward(input_features, out_features, weight, bias,
                              runningMean, runningVar, batch_mean,
                              batch_inv_std)
        return out_features

    @staticmethod
    def backward(ctx, grad_output):
        (input_features, out_features, weight, bias, runningMean, runningVar,
         batch_mean, batch_inv_std) = ctx.saved_tensors
        # The backward kernel needs the batch statistics recorded during a
        # training-mode forward pass.
        assert ctx.train
        d_input = grad_output.new()
        d_weight = torch.zeros_like(weight)
        d_bias = torch.zeros_like(bias)
        sparseconvnet.SCN.BatchNormalization_backward(
            input_features,
            d_input,
            out_features,
            grad_output.contiguous(),
            batch_mean,
            batch_inv_std,
            runningMean,
            runningVar,
            weight,
            bias,
            d_weight,
            d_bias,
            ctx.leakiness)
        # One gradient per forward argument; non-tensor args get None.
        return (d_input,
                optionalTensorReturn(d_weight),
                optionalTensorReturn(d_bias),
                None, None, None, None, None, None)