"vscode:/vscode.git/clone" did not exist on "f250aa415e7959769bc840700379e20c2bdbe5ac"
Commit 27b3cb0b authored by Benjamin Thomas Graham's avatar Benjamin Thomas Graham
Browse files

Update hello-world.py

parent 57b58239
...@@ -5,10 +5,11 @@ ...@@ -5,10 +5,11 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import torch import torch
import torch.legacy.nn as nn
import sparseconvnet.legacy as scn import sparseconvnet.legacy as scn
# Use the GPU if there is one, otherwise CPU # Use the GPU if there is one, otherwise CPU
tensorType = 'torch.cuda.FloatTensor' if torch.cuda.is_available() else 'torch.FloatTensor' dtype = 'torch.cuda.FloatTensor' if torch.cuda.is_available() else 'torch.FloatTensor'
model = scn.Sequential().add( model = scn.Sequential().add(
scn.SparseVggNet(2, 1, scn.SparseVggNet(2, 1,
...@@ -21,7 +22,7 @@ model = scn.Sequential().add( ...@@ -21,7 +22,7 @@ model = scn.Sequential().add(
scn.BatchNormReLU(32) scn.BatchNormReLU(32)
).add( ).add(
scn.SparseToDense(2) scn.SparseToDense(2)
).type(tensorType) ).type(dtype)
# output will be 10x10 # output will be 10x10
inputSpatialSize = model.suggestInputSize(torch.LongTensor([10, 10])) inputSpatialSize = model.suggestInputSize(torch.LongTensor([10, 10]))
...@@ -50,7 +51,7 @@ for y, line in enumerate(msg): ...@@ -50,7 +51,7 @@ for y, line in enumerate(msg):
input.precomputeMetadata(3) input.precomputeMetadata(3)
model.evaluate() model.evaluate()
input.type(tensorType) input.type(dtype)
output = model.forward(input) output = model.forward(input)
# Output is 1x32x10x10: our minibatch has 1 sample, the network has 32 output # Output is 1x32x10x10: our minibatch has 1 sample, the network has 32 output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment