# hello-world.py -- SparseConvNet example script
# Copyright 2016-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch

import sparseconvnet as scn

# Use the GPU if there is one, otherwise CPU.
use_cuda = torch.cuda.is_available()

# A small 2D sparse CNN: three VGG-style stages (two 3x3 submanifold
# convolutions followed by max-pooling with size 3, stride 2), a final 3x3
# submanifold convolution up to 32 feature planes, batch-norm + ReLU, and a
# SparseToDense layer that turns the sparse output into an ordinary dense
# tensor.  Arguments: dimension=2, 1 input feature plane.
model = scn.Sequential().add(
    scn.SparseVggNet(2, 1,
                     [['C', 8], ['C', 8], ['MP', 3, 2],
                      ['C', 16], ['C', 16], ['MP', 3, 2],
                      ['C', 24], ['C', 24], ['MP', 3, 2]])
).add(
    scn.SubmanifoldConvolution(2, 24, 32, 3, False)
).add(
    scn.BatchNormReLU(32)
).add(
    scn.SparseToDense(2, 32)
)
if use_cuda:
    model.cuda()

# Output will be 10x10; ask the model what input spatial size it needs to
# produce an output of that size.
input_spatial_size = model.input_spatial_size(torch.LongTensor([10, 10]))
# NOTE: renamed from `input` to avoid shadowing the Python builtin.
batch = scn.InputBatch(2, input_spatial_size)

msg = [
    " X   X  XXX  X    X    XX     X       X   XX   XXX   X    XXX   ",
    " X   X  X    X    X   X  X    X       X  X  X  X  X  X    X  X  ",
    " XXXXX  XX   X    X   X  X    X   X   X  X  X  XXX   X    X   X ",
    " X   X  X    X    X   X  X     X X X X   X  X  X  X  X    X  X  ",
    " X   X  XXX  XXX  XXX  XX       X   X     XX   X  X  XXX  XXX   "]

# Add a sample, marking active sites one at a time with set_location.
batch.add_sample()
for y, line in enumerate(msg):
    for x, c in enumerate(line):
        if c == 'X':
            location = torch.LongTensor([y, x])
            feature_vector = torch.FloatTensor([1])
            batch.set_location(location, feature_vector, 0)

# Add a second sample, marking all active sites at once with set_locations.
batch.add_sample()
locations = []
features = []
for y, line in enumerate(msg):
    for x, c in enumerate(line):
        if c == 'X':
            locations.append([y, x])
            features.append([1])
locations = torch.LongTensor(locations)
features = torch.FloatTensor(features)
batch.set_locations(locations, features, 0)

# Optional: allow metadata preprocessing to be done in batch preparation
# threads to improve GPU utilization.
#
# Parameter:
#    3 if using MP3/2 pooling or C3/2 convolutions for downsizing,
#    2 if using MP2 pooling for downsizing.
batch.precomputeMetadata(3)

model.train()
if use_cuda:
    batch.cuda()
# Call the model directly rather than model.forward() so hooks run normally.
output = model(batch)

# Output is 2x32x10x10: our minibatch has 2 samples, the network has 32 output
# feature planes, and 10x10 is the spatial size of the output.
print(output.shape, output.type())