-- Copyright 2016-present, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the license found in the
-- LICENSE file in the root directory of this source tree.

--[[
Parameters:
nPlanes : number of input planes (must be a multiple of 4)
eps : small number used to stabilise the standard deviation calculation
momentum : for calculating the running averages used at test time (default 0.9)
affine : learn a per-plane scale (weight) and shift (bias) (default 'true')
leakiness : apply an activation function in place: 0<=leakiness<=1.
0 for ReLU, values in (0,1) for LeakyReLU, 1 for no activation function.
]]
--[[
Sparse batch normalization for sparseconvnet, with an optional fused
(Leaky)ReLU activation.  The numerical work is done by typed C/CUDA
kernels reached through sparseconvnet.C; this file only manages the
Lua-side state (running statistics, affine parameters) and marshals
tensors across the FFI boundary.
]]
return function(sparseconvnet)
  local C = sparseconvnet.C -- bridge to the typed C/CUDA kernels
  -- Register the class on the sparseconvnet namespace, inheriting from
  -- nn.Module so it composes with standard torch containers.
  local BN,parent = torch.class(
    'sparseconvnet.BatchNormalization', 'nn.Module', sparseconvnet)

  -- nPlanes  : number of input feature planes; must be a multiple of 4
  --            (asserted below -- presumably a vectorisation requirement
  --            of the C kernels; confirm against the C implementation)
  -- eps      : stabiliser used in the inverse-std computation (default 1e-5)
  -- momentum : decay for the running mean/var used at test time (default 0.9)
  -- affine   : if true (default), learn per-plane weight (scale) and bias
  -- leakiness: fused activation -- 0 gives ReLU, values in (0,1) give
  --            LeakyReLU with that slope, 1 (default) means no activation
  function BN:__init(nPlanes, eps, momentum, affine, leakiness)
    parent.__init(self)
    assert(nPlanes%4==0)
    self.nPlanes=nPlanes
    self.leakiness=leakiness or 1
    if affine ~= nil then
      assert(type(affine) == 'boolean', 'affine has to be true/false')
      self.affine = affine
    else
      self.affine = true -- default: learn scale and bias
    end
    self.eps = eps or 1e-5
    -- Per-batch statistics written by the C forward pass and reused by
    -- the backward pass.
    self.saveMean = torch.Tensor(nPlanes)
    self.saveInvStd = torch.Tensor(nPlanes)
    self.momentum = momentum or 0.9
    -- Exponential running statistics, for evaluation mode.
    self.runningMean = torch.Tensor(nPlanes)
    self.runningVar = torch.Tensor(nPlanes)
    if self.affine then
      self.weight = torch.Tensor(nPlanes)
      self.bias = torch.Tensor(nPlanes)
      self.gradWeight = torch.Tensor(nPlanes)
      self.gradBias = torch.Tensor(nPlanes)
    end
    -- Sparse tensors travel as {features, metadata, spatialSize} tables;
    -- metadata and spatialSize are attached in updateOutput.
    self.output = {
      features = torch.Tensor()
    }
    self.gradInput = {
      features = torch.Tensor()
    }
    self:reset()
  end

  -- Reinitialise parameters and statistics to an identity transform:
  -- weight=1, bias=0, running mean 0, running variance 1.
  function BN:reset()
    if self.affine then
      self.weight:fill(1)
      self.bias:zero()
    end
    self.runningMean:zero()
    self.runningVar:fill(1)
    self.saveMean:zero()
    self.saveInvStd:fill(1)
  end

  -- Forward pass.  `input` is a sparse tensor table whose `features` is
  -- an (nActiveSites x nPlanes) matrix.  Normalisation -- and, when
  -- self.train is true, the running-statistics update -- happens in C.
  function BN:updateOutput(input)
    assert(input.features:size(2)==self.nPlanes)
    -- Batch norm is pointwise per active site, so the sparsity pattern
    -- is unchanged: share the input's metadata and spatial size.
    self.output.metadata = input.metadata
    self.output.spatialSize=input.spatialSize
    C.typedFn(self._type,'BatchNormalization_updateOutput')(
      input.features:cdata(),
      self.output.features:cdata(),
      self.saveMean:cdata(),
      self.saveInvStd:cdata(),
      self.runningMean:cdata(),
      self.runningVar:cdata(),
      self.weight and self.weight:cdata(), -- nil when affine == false
      self.bias and self.bias:cdata(),
      self.eps,
      self.momentum,
      self.train,
      self.leakiness)
    return self.output
  end

  -- Fused backward pass: a single C call produces gradInput and the
  -- parameter gradients, using the statistics saved by updateOutput.
  -- Only valid in training mode (it relies on the saved batch stats).
  function BN:backward(input, gradOutput)
    assert(self.train)
    C.typedFn(self._type,'BatchNormalization_backward')(
      input.features:cdata(),
      self.gradInput.features:cdata(),
      self.output.features:cdata(),
      gradOutput.features:cdata(),
      self.saveMean:cdata(),
      self.saveInvStd:cdata(),
      self.runningMean:cdata(),
      self.runningVar:cdata(),
      self.weight and self.weight:cdata(), -- nil when affine == false
      self.bias and self.bias:cdata(),
      self.gradWeight and self.gradWeight:cdata(),
      self.gradBias and self.gradBias:cdata(),
      self.leakiness)
    return self.gradInput
  end

  -- The standard nn split into updateGradInput/accGradParameters is not
  -- supported: both halves are fused into backward() above.
  function BN:updateGradInput(input, gradOutput)
    assert(false) --just call backward
  end

  function BN:accGradParameters(input, gradOutput, scale)
    assert(false) --just call backward
  end

  -- Human-readable summary, e.g.
  -- "BatchNormalization(nPlanes=64,eps=1e-05,momentum=0.9,ReLU)".
  -- NOTE(review): torch.class objects are conventionally stringified via
  -- a __tostring__ method (trailing double underscore); confirm that this
  -- single-underscore name is actually invoked by print/tostring.
  function BN:__tostring()
    local l
    if self.leakiness==0 then
      l=',ReLU'
    elseif self.leakiness==1/3 then
      l=',LeakyReLU(0.333..)'
    elseif self.leakiness<1 then
      l=',LeakyReLU('..self.leakiness..')'
    else
      l='' -- leakiness == 1: no activation
    end
    local s = 'BatchNormalization(' ..
    'nPlanes=' .. self.nPlanes..',' ..
    'eps=' .. self.eps .. ',' ..
    'momentum=' .. self.momentum .. l .. ')'
    return s
  end

  -- Release cached activations/gradients (Tensor:set() with no arguments
  -- detaches the storage) while keeping the module reusable.
  function BN:clearState()
    self.output={features=self.output.features:set()}
    self.gradInput={features=self.gradInput.features:set()}
  end

  -- Batch norm is pointwise: the input spatial size needed to produce an
  -- output of size nOut is nOut itself.
  function BN:suggestInputSize(nOut)
    return nOut
  end

  -- Convenience subclass: batch normalization with affine parameters and
  -- a fused ReLU (affine = true, leakiness = 0).
  local BN,parent = torch.class('sparseconvnet.BatchNormReLU',
    'sparseconvnet.BatchNormalization', sparseconvnet)
  function BN:__init(nPlanes, eps, momentum)
    parent.__init(self, nPlanes, eps, momentum, true, 0)
  end
end