# Copyright 2016-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Minimal SparseConvNet example.

Rasterizes an ASCII-art "Hello World" message into a sparse 2D input batch,
runs it through a small sparse VGG-style network, and prints the size and
tensor type of the dense output.
"""

import torch
# NOTE(review): torch.legacy only exists in old PyTorch releases (it was
# removed in 1.0); this example presumably requires such a release. `nn` is
# not referenced below but is kept from the original example.
import torch.legacy.nn as nn  # noqa: F401
import sparseconvnet.legacy as scn

# Use the GPU if there is one, otherwise CPU
dtype = 'torch.cuda.FloatTensor' if torch.cuda.is_available() else 'torch.FloatTensor'

# ASCII-art message; every 'X' character becomes one active input site.
MSG = [
    " X   X  XXX  X    X    XX     X       X   XX   XXX   X    XXX   ",
    " X   X  X    X    X   X  X    X       X  X  X  X  X  X    X  X  ",
    " XXXXX  XX   X    X   X  X    X   X   X  X  X  XXX   X    X   X ",
    " X   X  X    X    X   X  X     X X X X   X  X  X  X  X    X  X  ",
    " X   X  XXX  XXX  XXX  XX       X   X     XX   X  X  XXX  XXX   "]


def _message_locations(lines):
    """Return the [x, y] coordinates of every 'X' in *lines*.

    x is the character's column within its line and y is the line's index,
    visited row by row, left to right.
    """
    return [[x, y]
            for y, line in enumerate(lines)
            for x, c in enumerate(line)
            if c == 'X']


def _build_model():
    """Build the sparse VGG-style network followed by a dense conversion,
    cast to the module-level `dtype`."""
    return scn.Sequential().add(
        scn.SparseVggNet(2, 1,
                         [['C', 8], ['C', 8], ['MP', 3, 2],
                          ['C', 16], ['C', 16], ['MP', 3, 2],
                          ['C', 24], ['C', 24], ['MP', 3, 2]])
    ).add(
        scn.ValidConvolution(2, 24, 32, 3, False)
    ).add(
        scn.BatchNormReLU(32)
    ).add(
        scn.SparseToDense(2)
    ).type(dtype)


def main():
    """Build the model and the input batch, run one forward pass, and print
    the output's size and type."""
    model = _build_model()

    # output will be 10x10
    input_spatial_size = model.suggestInputSize(torch.LongTensor([10, 10]))
    # Named input_batch to avoid shadowing the builtin `input`.
    input_batch = scn.InputBatch(2, input_spatial_size)

    input_batch.addSample()
    locations = _message_locations(MSG)
    # One feature with value 1 for each active site.
    features = [[1] for _ in locations]
    input_batch.setLocations(torch.LongTensor(locations),
                             torch.FloatTensor(features), 0)

    # Optional: allow metadata preprocessing to be done in batch preparation
    # threads to improve GPU utilization.
    #
    # Parameter:
    #    3 if using MP3/2 pooling or C3/2 convolutions for downsizing,
    #    2 if using MP2 pooling for downsizing.
    input_batch.precomputeMetadata(3)

    model.evaluate()
    input_batch.type(dtype)
    output = model.forward(input_batch)

    # Output is 1x32x10x10: our minibatch has 1 sample, the network has 32
    # output feature planes, and 10x10 is the spatial size of the output.
    print(output.size(), output.type())


if __name__ == "__main__":
    main()