-- Copyright 2016-present, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the license found in the
-- LICENSE file in the root directory of this source tree.

-- Train on the GPU if there is one, otherwise CPU.
-- `scn.cutorch` is non-nil only when CUDA support was loaded.
scn = require 'sparseconvnet'
tensorType = scn.cutorch and 'torch.CudaTensor' or 'torch.FloatTensor'

-- The network: a sparse VGG-style feature extractor (three MP3/2 stages)
-- followed by a sparse convolution, batchnorm+ReLU, and a conversion to a
-- dense tensor.  Cast to tensorType so it runs on GPU when available.
model = scn.Sequential()
  :add(scn.SparseVggNet(2, 1, { -- dimension 2, 1 input plane
        {'C', 8},     -- 3x3 VSC convolution, 8 output planes, batchnorm, ReLU
        {'C', 8},     -- and another
        {'MP', 3, 2}, -- max pooling, size 3, stride 2
        {'C', 16},    -- etc
        {'C', 16},
        {'MP', 3, 2},
        {'C', 24},
        {'C', 24},
        {'MP', 3, 2}}))
  :add(scn.Convolution(2, 24, 32, 3, 1, false)) -- an SC convolution on top
  :add(scn.BatchNormReLU(32))
  :add(scn.SparseToDense(2))
  :type(tensorType)

--[[
  To use the network we must create an scn.InputBatch with the right
  dimensionality.  If we want the output to have spatial size 10x10, we can
  find the appropriate input size, given that we use three layers of MP3/2
  max-pooling and finish with an SC convolution.
]]

-- Ask the model what input size produces a 10x10 output (103x103 here),
-- then allocate a 2-dimensional input batch of that spatial size.
inputSpatialSize = model:suggestInputSize(torch.LongTensor{10, 10}) --103x103
input = scn.InputBatch(2, inputSpatialSize)

--Now we build the input batch, sample by sample, and active site by active site.
msg = {
  " O   O  OOO  O    O    OO     O       O   OO   OOO   O    OOO   ",
  " O   O  O    O    O   O  O    O       O  O  O  O  O  O    O  O  ",
  " OOOOO  OO   O    O   O  O    O   O   O  O  O  OOO   O    O   O ",
  " O   O  O    O    O   O  O     O O O O   O  O  O  O  O    O  O  ",
  " O   O  OOO  OOO  OOO  OO       O   O     OO   O  O  OOO  OOO   ",
}
-- Sample 1: register each 'O' character as an active site with a
-- one-dimensional feature vector of {1}.
input:addSample()
for y, line in ipairs(msg) do
  for x = 1, string.len(line) do
    if string.sub(line, x, x) == 'O' then
      local location = torch.LongTensor{y, x}
      local featureVector = torch.FloatTensor{1}
      input:setLocation(location, featureVector, 0)
    end
  end
end
--We can also use setLocations
-- Sample 2: same active sites as sample 1, but collected into two parallel
-- arrays and handed to the batch in a single setLocations call.
input:addSample()
local locations = {}
local featureVectors = {}
for y, line in ipairs(msg) do
  for x = 1, string.len(line) do
    if string.sub(line, x, x) == 'O' then
      table.insert(locations, {y, x})
      table.insert(featureVectors, {1})
    end
  end
end
input:setLocations(
  torch.LongTensor(locations),
  torch.FloatTensor(featureVectors),
  0)

--[[
  Optional: allow metadata preprocessing to be done in batch preparation threads
  to improve GPU utilization.

  Parameter:
  3 if using MP3/2 or size-3 stride-2 convolutions for downsizing,
  2 if using MP2
]]
input:precomputeMetadata(3)

-- Inference mode (batchnorm uses running statistics), move the batch to the
-- same tensor type as the model, and run the forward pass.
model:evaluate()
input:type(tensorType)
output = model:forward(input)

--[[
  Output is 2x32x10x10: our minibatch has 2 samples, the network has 32 output
  feature planes, and 10x10 is the spatial size of the output.
]]
print(output:size(), output:type())