-- init.lua (3.32 KB), committed by Benjamin Thomas Graham
-- Copyright 2016-present, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the license found in the
-- LICENSE file in the root directory of this source tree.

-- Module table: filled in by the statements below and returned at the end of
-- this file.
local sparseconvnet = {}
-- Required dependencies; a failing require aborts loading of this file.
sparseconvnet.nn=require 'nn'
sparseconvnet.optim=require 'optim'
-- Optional GPU packages are loaded through pcall so that a missing package
-- does not raise. Each assignment keeps only the FIRST pcall result, so these
-- fields hold a boolean (true when the require succeeded), not the package.
-- NOTE(review): presumably used as availability flags, and the packages are
-- presumably reached through the globals they define -- confirm with callers.
sparseconvnet.cutorch = pcall(require , 'cutorch')
sparseconvnet.cunn = pcall(require , 'cunn')
sparseconvnet.cudnn = pcall(require, 'cudnn')
-- Load the submodules in the fixed order listed here. Every submodule is
-- expected to return a function; that function is invoked with the shared
-- `sparseconvnet` table as its only argument.
do
  local submodules = {
    'sparseconvnet.C',
    'sparseconvnet.AffineReluTrivialConvolution',
    'sparseconvnet.AveragePooling',
    'sparseconvnet.BatchNormalization',
    'sparseconvnet.BatchNormalizationInTensor',
    'sparseconvnet.BatchwiseDropout',
    'sparseconvnet.CAddTable',
    'sparseconvnet.ClassificationTrainValidate',
    'sparseconvnet.ConcatTable',
    'sparseconvnet.Convolution',
    'sparseconvnet.DataLoader',
    'sparseconvnet.Deconvolution',
    'sparseconvnet.DenseNetBlock',
    'sparseconvnet.Identity',
    'sparseconvnet.InputBatch',
    'sparseconvnet.JoinTable',
    'sparseconvnet.LeakyReLU',
    'sparseconvnet.MaxPooling',
    'sparseconvnet.Metadata',
    'sparseconvnet.NetworkArchitectures',
    'sparseconvnet.NetworkInNetwork',
    'sparseconvnet.ReLU',
    'sparseconvnet.Sequential',
    'sparseconvnet.SparseToDense',
    'sparseconvnet.ValidConvolution',
  }
  for index = 1, #submodules do
    local install = require(submodules[index])
    install(sparseconvnet)
  end
end

--- Re-initialise, in place, the parameters of the dense layers found in `model`.
-- Following https://github.com/facebook/fb.resnet.torch/blob/master/models/preresnet.lua
--
-- * Spatial convolutions: weights are drawn from N(0, sqrt(2/n)) with
--   n = kW*kH*nInputPlane. The bias is dropped when cudnn >= 4000 is loaded,
--   otherwise it is zeroed (if the layer has one).
-- * Spatial batch normalisation: weight filled with 1 and bias zeroed, when
--   the layer has those parameters.
-- * nn.Linear: bias zeroed, when the layer has one.
--
-- @param model object exposing `findModules(typename)`, which returns a table
--   of matching layers
-- @return the same `model`, to allow call chaining
function sparseconvnet.initializeDenseModel(model)
  -- cudnn is an optional dependency (this file loads it through pcall), so the
  -- global may be absent. rawget avoids tripping strict-global checkers.
  local cudnnLib = rawget(_G, 'cudnn')
  local dropConvBias = type(cudnnLib) == 'table'
    and type(cudnnLib.version) == 'number'
    and cudnnLib.version >= 4000

  local function ConvInit(name)
    for _,v in pairs(model:findModules(name)) do
      local n = v.kW*v.kH*v.nInputPlane --use nInputPlane instead of nOutputPlane
      v.weight:normal(0,math.sqrt(2/n))
      if dropConvBias then
        v.bias = nil
        v.gradBias = nil
      elseif v.bias then -- layers built without a bias have nothing to zero
        v.bias:zero()
      end
    end
  end

  local function BNInit(name)
    for _,v in pairs(model:findModules(name)) do
      -- Layers without affine parameters carry neither weight nor bias.
      if v.weight then
        v.weight:fill(1)
      end
      if v.bias then
        v.bias:zero()
      end
    end
  end

  ConvInit('cudnn.SpatialConvolution')
  ConvInit('nn.SpatialConvolution')
  BNInit('fbnn.SpatialBatchNormalization')
  BNInit('cudnn.SpatialBatchNormalization')
  BNInit('nn.SpatialBatchNormalization')
  for _,v in pairs(model:findModules('nn.Linear')) do
    if v.bias then
      v.bias:zero()
    end
  end
  return model
end
--- Coerce `x` to a torch.LongTensor.
-- A number is broadcast into a new tensor of length `dimension`; a table is
-- passed to the torch.LongTensor constructor; any other value is returned
-- unchanged after asserting that its first dimension has size `dimension`.
-- @param x number, table, or object exposing `size(1)`
-- @param dimension length used for the number case and for the size check
-- @return a torch.LongTensor, or `x` itself in the last case
function sparseconvnet.toLongTensor(x,dimension)
  local kind = type(x)
  if kind == 'table' then
    return torch.LongTensor(x)
  end
  if kind == 'number' then
    local blank = torch.LongTensor(dimension)
    return blank:fill(x)
  end
  assert(x:size(1) == dimension)
  return x
end

--- Coerce `x` to a torch.DoubleTensor.
-- A number is broadcast into a new tensor of length `dimension`; a table is
-- passed to the torch.DoubleTensor constructor; any other value is returned
-- as-is, with no size check.
-- @param x number, table, or any other value
-- @param dimension length of the tensor created in the number case
-- @return a torch.DoubleTensor, or `x` itself in the last case
function sparseconvnet.toDoubleTensor(x,dimension)
  local kind = type(x)
  if kind == 'table' then
    return torch.DoubleTensor(x)
  end
  if kind == 'number' then
    local blank = torch.DoubleTensor(dimension)
    return blank:fill(x)
  end
  return x
end

--- Make `mod` and every module nested under it point at one `shared` table.
-- The table is created with zeroed counters if `mod` does not have one yet.
-- Its `rulesBuffer` field is reset on every call: a fresh CUDA tensor when
-- `mod._type` names a Cuda tensor type, nil otherwise.
-- @param mod module with a `_type` string and, optionally, a `modules` list
function sparseconvnet.shareShared(mod)
  if not mod.shared then
    mod.shared = {forwardPassMultiplyAddCount=0, forwardPassHiddenStates=0}
  end
  local shared = mod.shared

  -- The rules buffer is only needed on the GPU; characters 7..10 of the type
  -- string are compared against 'Cuda'.
  local buffer = nil
  if mod._type:sub(7,10) == 'Cuda' then
    if sparseconvnet.ruleBookBits==64 then
      buffer = torch.CudaLongTensor()
    else
      buffer = torch.CudaIntTensor()
    end
  end
  shared.rulesBuffer = buffer

  -- Depth-first propagation of the shared table through the `modules` lists.
  local function propagate(node)
    node.shared = shared
    local children = node.modules
    if children then
      for _,child in ipairs(children) do
        propagate(child)
      end
    end
  end
  propagate(mod)
end

return sparseconvnet