-- Copyright 2016-present, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the license found in the
-- LICENSE file in the root directory of this source tree.

-- Module factory: the sparseconvnet package calls this function with itself
-- so the class body can reference sibling components (BatchNormalization,
-- SubmanifoldConvolution, ...) through the `sparseconvnet` upvalue.
return function(sparseconvnet)
  local C = sparseconvnet.C -- not referenced in this file; presumably kept for parity with sibling modules — TODO confirm

  -- Register sparseconvnet.DenseNetBlock as a subclass of nn.Container;
  -- `parent` is the nn.Container metatable, used for parent.__init(self).
  local DenseNetBlock, parent = torch.class(
    'sparseconvnet.DenseNetBlock', 'nn.Container', sparseconvnet)

  --- Build a DenseNet block: one input batch-norm followed by nExtraLayers
  -- bottleneck units. Each unit is
  --   AffineReluTrivialConvolution -> BatchNormalization ->
  --   SubmanifoldConvolution(3^d) -> BatchNormalizationInTensor,
  -- and each appends growthRate feature planes into the shared
  -- self.output.features tensor (nActive x nOutputPlanes).
  -- @param dimension    spatial dimensionality of the input
  -- @param nInputPlanes number of input feature planes
  -- @param nExtraLayers number of dense layers to append (default 2)
  -- @param growthRate   planes added per dense layer (default 16)
  function DenseNetBlock:__init(dimension, nInputPlanes, nExtraLayers,
      growthRate)
    parent.__init(self)
    -- Bug fix: was `self.dimension=dimensions` — `dimensions` is an undeclared
    -- global (always nil); the parameter is named `dimension`.
    self.dimension=dimension
    self.nInputPlanes=nInputPlanes
    self.nExtraLayers=nExtraLayers or 2
    self.growthRate=growthRate or 16
    assert(self.nExtraLayers>=1)
    -- Bug fix: use the defaulted self.* values below, not the raw parameters,
    -- so the `or 2` / `or 16` defaults actually work when args are omitted.
    self.nOutputPlanes=nInputPlanes+self.nExtraLayers*self.growthRate

    self.output={
      features=torch.Tensor(), --nActive x self.nOutputPlanes
    }

    --Module 1: Batchnorm the input into the start of self.output
    self:add(sparseconvnet.BatchNormalizationInTensor(nInputPlanes,nil,nil,0))
    self.modules[1].output=self.output
    self.gradInput=self.modules[1].gradInput

    for i = 1, self.nExtraLayers do
      -- Dense connectivity: layer i sees the input plus all previously
      -- produced growthRate-sized slices of the shared output tensor.
      local nFeatures = self.nInputPlanes + (i-1)*self.growthRate
      local nFeaturesB = 4*self.growthRate -- bottleneck width
      --Module 4*i-2: 1x1 bottleneck convolution
      self:add(sparseconvnet.AffineReluTrivialConvolution(nFeatures, nFeaturesB, true))
      --Module 4*i-1
      self:add(sparseconvnet.BatchNormalization(nFeaturesB,nil,nil,true,0))
      --Module 4*i: size-3 submanifold convolution producing growthRate planes
      self:add(sparseconvnet.SubmanifoldConvolution(dimension, nFeaturesB, self.growthRate,
          3, false))
      --Module 4*i+1: batch-norm writing into its column slice of self.output
      self:add(sparseconvnet.BatchNormalizationInTensor(self.growthRate,nil,nil,
          self.nInputPlanes+(i-1)*self.growthRate))
      self.modules[4*i+1].output=self.output
    end

    -- Expose the (first) submanifold convolution's filter geometry so this
    -- block can be composed like a plain convolution.
    self.filterSize = self.modules[4].filterSize
    self.filterStride = self.modules[4].filterStride
    self.filterSizeString = self.modules[4].filterSizeString
  end

  --- Forward pass: chain the input through all 4*nExtraLayers+1 child
  -- modules. The shared self.output.features tensor is resized up front;
  -- the BatchNormalizationInTensor modules write into their column slices,
  -- so the chained return values are intermediates and self.output is the
  -- real result.
  function DenseNetBlock:updateOutput(input)
    assert(input.features:size(2) == self.nInputPlanes)
    self.output.spatialSize = input.spatialSize
    self.output.metadata = input.metadata
    local nActive = input.features:size(1)
    self.output.features:resize(nActive, self.nOutputPlanes)
    local current = input
    local nModules = 4*self.nExtraLayers+1
    for moduleIndex = 1, nModules do
      current = self.modules[moduleIndex]:updateOutput(current)
    end
    return self.output
  end

  --- Backward pass. Every bottleneck entry module (index 4*i-2) is pointed
  -- at the shared gradOutput tensor first, since the dense connections mean
  -- each layer's input gradient accumulates into the block-wide gradient.
  -- Then gradients are chained back through the modules in reverse, feeding
  -- each module the previous module's forward output.
  function DenseNetBlock:backward(input, gradOutput)
    for layer = 1, self.nExtraLayers do
      self.modules[4*layer-2].gradInput = gradOutput
    end
    local grad = gradOutput
    local lastModule = 4*self.nExtraLayers+1
    for m = lastModule, 2, -1 do
      grad = self.modules[m]:backward(self.modules[m-1].output, grad)
    end
    self.modules[1]:backward(input, grad)
    -- self.gradInput aliases self.modules[1].gradInput (set in __init).
    return self.gradInput
  end

  --- Convert the shared output tensor and every child module to the given
  -- tensor type (e.g. 'torch.CudaTensor').
  -- NOTE(review): tensorCache is accepted but unused here — confirm whether
  -- child modules need it forwarded.
  function DenseNetBlock:type(type,tensorCache)
    self._type = type
    self.output.features = self.output.features:type(type)
    for _, child in ipairs(self.modules) do
      child:type(type)
    end
  end

  --- Human-readable summary, e.g. "DenseNetBlock(64->64+2*16=96)".
  function DenseNetBlock:__tostring()
    return 'DenseNetBlock(' .. self.nInputPlanes .. '->'
        .. self.nInputPlanes .. '+' .. self.nExtraLayers .. '*'
        .. self.growthRate .. '=' .. self.nOutputPlanes .. ')'
  end

  --- Release intermediate buffers: clear every child module, then reset the
  -- shared output table (the features tensor is emptied in place via :set()).
  -- NOTE(review): assigning a fresh table to self.output drops the aliasing
  -- established in __init (modules[1].output and modules[4*i+1].output still
  -- reference the old table) — verify forward still works after clearState.
  function DenseNetBlock:clearState()
    for _, child in ipairs(self.modules) do
      child:clearState()
    end
    local emptyFeatures = self.output.features:set()
    self.output = {
      features = emptyFeatures,
      nPlanes = self.nOutputPlanes,
      dimension = self.dimension,
    }
  end

  --- A DenseNetBlock preserves spatial size: the input spatial size needed
  -- to produce an output of size nOut is nOut itself.
  function DenseNetBlock:suggestInputSize(nOut)
    return nOut
  end
end