"git@developer.sourcefind.cn:OpenDAS/dcnv3.git" did not exist on "bdd98bcb7f7b96bb9e8afae7e3fe4a077e5d437c"
Commit 27b3cb0b authored by Benjamin Thomas Graham's avatar Benjamin Thomas Graham
Browse files

Update hello-world.py

parent 57b58239
......@@ -5,10 +5,11 @@
# LICENSE file in the root directory of this source tree.
import torch
import torch.legacy.nn as nn
import sparseconvnet.legacy as scn
# Use the GPU if there is one, otherwise CPU
tensorType = 'torch.cuda.FloatTensor' if torch.cuda.is_available() else 'torch.FloatTensor'
dtype = 'torch.cuda.FloatTensor' if torch.cuda.is_available() else 'torch.FloatTensor'
model = scn.Sequential().add(
scn.SparseVggNet(2, 1,
......@@ -21,7 +22,7 @@ model = scn.Sequential().add(
scn.BatchNormReLU(32)
).add(
scn.SparseToDense(2)
).type(tensorType)
).type(dtype)
# output will be 10x10
inputSpatialSize = model.suggestInputSize(torch.LongTensor([10, 10]))
......@@ -50,7 +51,7 @@ for y, line in enumerate(msg):
input.precomputeMetadata(3)
model.evaluate()
input.type(tensorType)
input.type(dtype)
output = model.forward(input)
# Output is 1x32x10x10: our minibatch has 1 sample, the network has 32 output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment