-- Copyright 2016-present, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the license found in the
-- LICENSE file in the root directory of this source tree.

--[[
Batch normalization whose output is written into a column slice of a larger,
externally supplied feature tensor: updateOutput narrows (does not allocate)
self.output.features, so the caller must provide it with at least
outputColumnOffset + nPlanes columns. The affine and leakiness options of the
parent BatchnNormalization class are fixed by the constructor here.

Parameters:
nPlanes : number of input planes
eps : small number used to stabilise standard deviation calculation
momentum : for calculating running average for testing (default 0.9)
outputColumnOffset : zero-based column offset at which the nPlanes output
columns are written inside output.features
]]
return function(sparseconvnet)
  local C = sparseconvnet.C
  local BN,parent = torch.class(
    'sparseconvnet.BatchNormalizationInTensor', 'sparseconvnet.BatchNormalization', sparseconvnet)

  --- Construct the module and record where its output columns live.
  -- NOTE(review): the trailing positional args to the parent constructor
  -- appear to be affine=false, leakiness=1 (i.e. no activation) — confirm
  -- against BatchNormalization:__init.
  function BN:__init(nPlanes, eps, momentum, outputColumnOffset)
    parent.__init(self, nPlanes, eps, momentum, false, 1)
    self.outputColumnOffset = outputColumnOffset
  end

  --- Forward pass: batch-normalise input.features, writing the result into an
  -- nPlanes-wide column slice of the (externally provided) output tensor.
  -- Metadata and spatial size are passed through from the input unchanged.
  function BN:updateOutput(input)
    self.output.metadata = input.metadata
    self.output.spatialSize = input.spatialSize
    -- View onto the destination columns; narrow() shares storage, so the C
    -- kernel writes directly into the enclosing tensor.
    local outSlice =
      self.output.features:narrow(2, 1 + self.outputColumnOffset, self.nPlanes)
    C.typedFn(self._type, 'BatchNormalizationInTensor_updateOutput')(
      input.features:cdata(),
      outSlice:cdata(),
      self.saveMean:cdata(),
      self.saveInvStd:cdata(),
      self.runningMean:cdata(),
      self.runningVar:cdata(),
      self.weight and self.weight:cdata(),
      self.bias and self.bias:cdata(),
      self.eps,
      self.momentum,
      self.train,
      self.leakiness)
    return self.output
  end

  --- Fused backward pass: computes gradInput and accumulates parameter
  -- gradients in a single C call (hence updateGradInput/accGradParameters are
  -- not supported separately). Only valid in training mode, since it relies on
  -- saveMean/saveInvStd captured by the forward pass.
  function BN:backward(input, gradOutput)
    assert(self.train)
    -- Same column slices of the output/gradOutput tensors that the forward
    -- pass wrote into / that upstream produced.
    local outSlice =
      self.output.features:narrow(2, 1 + self.outputColumnOffset, self.nPlanes)
    local gradSlice =
      gradOutput.features:narrow(2, 1 + self.outputColumnOffset, self.nPlanes)
    C.typedFn(self._type, 'BatchNormalization_backward')(
      input.features:cdata(),
      self.gradInput.features:cdata(),
      outSlice:cdata(),
      gradSlice:cdata(),
      self.saveMean:cdata(),
      self.saveInvStd:cdata(),
      self.runningMean:cdata(),
      self.runningVar:cdata(),
      self.weight and self.weight:cdata(),
      self.bias and self.bias:cdata(),
      self.gradWeight and self.gradWeight:cdata(),
      self.gradBias and self.gradBias:cdata(),
      self.leakiness)
    return self.gradInput
  end

  --- Not supported: the backward computation is fused, so gradInput cannot be
  -- computed on its own. Raises a descriptive error (the original
  -- assert(false) only produced "assertion failed!").
  function BN:updateGradInput(input, gradOutput)
    error('BatchNormalizationInTensor: updateGradInput is not supported; call backward() instead')
  end

  --- Not supported: parameter gradients are accumulated inside backward().
  -- Raises a descriptive error instead of the bare "assertion failed!".
  function BN:accGradParameters(input, gradOutput, scale)
    error('BatchNormalizationInTensor: accGradParameters is not supported; call backward() instead')
  end

  --- Human-readable description, including the fused activation implied by
  -- self.leakiness (0 = ReLU, (0,1) = LeakyReLU, >=1 = none).
  function BN:__tostring()
    local activation
    if self.leakiness == 0 then
      activation = ',ReLU'
    elseif self.leakiness == 1 / 3 then
      activation = ',LeakyReLU(0.333..)'
    elseif self.leakiness < 1 then
      activation = string.format(',LeakyReLU(%s)', self.leakiness)
    else
      activation = ''
    end
    -- %s applies tostring(), matching the numeric formatting of the .. operator.
    return string.format(
      'BatchNormalizationInTensor(nPlanes=%s,eps=%s,momentum=%s%s)',
      self.nPlanes, self.eps, self.momentum, activation)
  end

  --- Drop cached state: empty the feature tensors via :set() and rebuild the
  -- output/gradInput tables around them, discarding metadata/spatialSize.
  function BN:clearState()
    for _, slot in ipairs({'output', 'gradInput'}) do
      self[slot] = {features = self[slot].features:set()}
    end
  end

  -- Batch normalization is pointwise over spatial sites, so producing an
  -- output of spatial size nOut requires an input of the same size.
  function BN:suggestInputSize(nOut)
    return nOut
  end
end