ConcatTable.lua 1.57 KB
Newer Older
Benjamin Thomas Graham's avatar
Benjamin Thomas Graham committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
-- Copyright 2016-present, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the license found in the
-- LICENSE file in the root directory of this source tree.

--- Factory that registers `sparseconvnet.ConcatTable` in the given namespace.
-- The container forwards the same input to every child module and returns
-- the table of their outputs; on the backward pass the children's gradients
-- with respect to `features` are summed into a single gradInput.
-- @param sparseconvnet package table the class is registered into
return function(sparseconvnet)
  local ConcatTable, parent = torch.class(
    'sparseconvnet.ConcatTable', 'nn.ConcatTable', sparseconvnet)

  --- Create an empty container.
  -- gradInput only carries a `features` tensor; it is filled by `backward`.
  function ConcatTable:__init()
    parent.__init(self)
    self.modules = {}
    self.output = {}
    self.gradInput = {
      features = torch.Tensor()
    }
    -- NOTE(review): shareShared is defined elsewhere in sparseconvnet;
    -- presumably it propagates shared buffers to the children — confirm.
    sparseconvnet.shareShared(self)
  end

  --- Append a child module.
  -- @param module module to add as a new branch
  -- @return self, so calls can be chained
  function ConcatTable:add(module)
    table.insert(self.modules, module)
    sparseconvnet.shareShared(self)
    return self
  end

  --- Forward `input` through every child module.
  -- @param input passed unchanged to each child's `forward`
  -- @return self.output, where output[i] is the i-th child's output
  function ConcatTable:updateOutput(input)
    local nModules = #self.modules
    for i = 1, nModules do
      self.output[i] = self.modules[i]:forward(input)
    end
    -- Drop entries left over from an earlier call that produced more outputs
    -- than there are modules now. Clearing from the top down keeps the table
    -- a proper sequence (no holes) throughout.
    for i = #self.output, nModules + 1, -1 do
      self.output[i] = nil
    end
    return self.output
  end

  --- Backpropagate through every child and sum their feature gradients.
  -- @param input the input previously given to `updateOutput`
  -- @param gradOutput table; gradOutput[i] is the gradient for output[i]
  -- @param scale optional scale forwarded unchanged to each child's
  --   `backward` (nil lets the child use its own default)
  -- @return self.gradInput, whose `features` is the element-wise sum of the
  --   children's gradInput.features
  function ConcatTable:backward(input, gradOutput, scale)
    local nModules = #self.modules
    assert(nModules > 0,
      'sparseconvnet.ConcatTable:backward called with no modules')
    local gradInputs = {}
    for i = 1, nModules do
      gradInputs[i] = self.modules[i]:backward(input, gradOutput[i], scale)
    end
    -- Copy (rather than alias) the first child's gradient so the in-place
    -- accumulation below never writes into a child's own gradInput tensor.
    local features = self.gradInput.features
    features:resizeAs(gradInputs[1].features):copy(gradInputs[1].features)
    for i = 2, nModules do
      features:add(gradInputs[i].features)
    end
    return self.gradInput
  end

  --- Release buffers held by this container and by its children.
  -- @return self, matching nn.Module:clearState so calls can be chained
  function ConcatTable:clearState()
    for _, m in ipairs(self.modules) do
      m:clearState()
    end
    self.output = {}
    -- Tensor:set() with no arguments empties the tensor in place and returns
    -- it, so the same tensor object is kept as gradInput.features.
    self.gradInput = {features = self.gradInput.features:set()}
    return self
  end

  --- Delegate the input-size suggestion to the first child module.
  -- NOTE(review): assumes all branches agree on the input size — confirm.
  -- @param nOut forwarded unchanged to the first child
  function ConcatTable:suggestInputSize(nOut)
    return self.modules[1]:suggestInputSize(nOut)
  end
end