Commit d8b64558 authored by Benjamin Thomas Graham's avatar Benjamin Thomas Graham
Browse files

UNet; fix example data loader permutation

parent 00891eb5
...@@ -17,6 +17,7 @@ from .sparseToDense import SparseToDense ...@@ -17,6 +17,7 @@ from .sparseToDense import SparseToDense
from .denseToSparse import DenseToSparse from .denseToSparse import DenseToSparse
from .tables import * from .tables import *
def SparseVggNet(dimension, nInputPlanes, layers): def SparseVggNet(dimension, nInputPlanes, layers):
""" """
VGG style nets VGG style nets
...@@ -36,7 +37,7 @@ def SparseVggNet(dimension, nInputPlanes, layers): ...@@ -36,7 +37,7 @@ def SparseVggNet(dimension, nInputPlanes, layers):
m.add(BatchNormReLU(nPlanes)) m.add(BatchNormReLU(nPlanes))
elif x[0] == 'C' and len(x) == 3: elif x[0] == 'C' and len(x) == 3:
m.add(ConcatTable() m.add(ConcatTable()
.add( .add(
SubmanifoldConvolution(dimension, nPlanes, x[1], 3, False) SubmanifoldConvolution(dimension, nPlanes, x[1], 3, False)
).add( ).add(
Sequential() Sequential()
...@@ -105,7 +106,7 @@ def SparseVggNet(dimension, nInputPlanes, layers): ...@@ -105,7 +106,7 @@ def SparseVggNet(dimension, nInputPlanes, layers):
.add(SubmanifoldConvolution(dimension, x[3], x[3], 3, False)) .add(SubmanifoldConvolution(dimension, x[3], x[3], 3, False))
.add(BatchNormReLU(x[3])) .add(BatchNormReLU(x[3]))
.add(Deconvolution(dimension, x[3], x[3], 3, 2, False)) .add(Deconvolution(dimension, x[3], x[3], 3, 2, False))
) )
.add(Sequential() .add(Sequential()
.add(Convolution(dimension, nPlanes, x[4], 3, 2, False)) .add(Convolution(dimension, nPlanes, x[4], 3, 2, False))
.add(BatchNormReLU(x[4])) .add(BatchNormReLU(x[4]))
...@@ -133,6 +134,7 @@ def SparseVggNet(dimension, nInputPlanes, layers): ...@@ -133,6 +134,7 @@ def SparseVggNet(dimension, nInputPlanes, layers):
m.add(BatchNormReLU(nPlanes)) m.add(BatchNormReLU(nPlanes))
return m return m
def SparseResNet(dimension, nInputPlanes, layers): def SparseResNet(dimension, nInputPlanes, layers):
""" """
pre-activated ResNet pre-activated ResNet
...@@ -202,3 +204,72 @@ def SparseResNet(dimension, nInputPlanes, layers): ...@@ -202,3 +204,72 @@ def SparseResNet(dimension, nInputPlanes, layers):
m.add(AddTable()) m.add(AddTable())
m.add(BatchNormReLU(nPlanes)) m.add(BatchNormReLU(nPlanes))
return m return m
def ResNetUNet(dimension, nPlanes, reps, depth=4):
"""
U-Net style network with ResNet-style blocks.
For voxel level prediction:
import sparseconvnet as scn
import torch.nn
class Model(nn.Module):
def __init__(self):
nn.Module.__init__(self)
self.sparseModel = scn.Sequential().add(
scn.ValidConvolution(3, nInputFeatures, 64, 3, False)).add(
scn.ResNetUNet(3, 64, 2, 4))
self.linear = nn.Linear(64, nClasses)
def forward(self,x):
x=self.sparseModel(x).features
x=self.linear(x)
return x
"""
def res(m, a, b):
m.add(ConcatTable()
.add(Identity() if a == b else NetworkInNetwork(a, b, False))
.add(Sequential()
.add(BatchNormReLU(a))
.add(SubmanifoldConvolution(dimension, a, b, 3, False))
.add(BatchNormReLU(b))
.add(SubmanifoldConvolution(dimension, b, b, 3, False))))\
.add(AddTable())
def v(depth, nPlanes):
m = Sequential()
if depth == 1:
for _ in range(reps):
res(m, nPlanes, nPlanes)
else:
m = Sequential()
for _ in range(reps):
res(m, nPlanes, nPlanes)
m.add(
ConcatTable() .add(
Identity()) .add(
Sequential() .add(
BatchNormReLU(nPlanes)) .add(
Convolution(
dimension,
nPlanes,
nPlanes,
2,
2,
False)) .add(
v(
depth - 1,
nPlanes)) .add(
BatchNormReLU(nPlanes)) .add(
Deconvolution(
dimension,
nPlanes,
nPlanes,
2,
2,
False))))
m.add(JoinTable())
for i in range(reps):
res(m, 2 * nPlanes if i == 0 else nPlanes, nPlanes)
return m
m = v(depth, nPlanes)
m.add(BatchNormReLU(nPlanes))
return m
...@@ -85,7 +85,7 @@ def train(spatial_size, Scale, precomputeStride): ...@@ -85,7 +85,7 @@ def train(spatial_size, Scale, precomputeStride):
tdi = scn.threadDatasetIterator(bd) tdi = scn.threadDatasetIterator(bd)
def iter(): def iter():
randperm = torch.randperm(len(d)) randperm.copy_(torch.randperm(len(d)))
return tdi() return tdi()
return iter return iter
...@@ -120,7 +120,7 @@ def val(spatial_size, Scale, precomputeStride): ...@@ -120,7 +120,7 @@ def val(spatial_size, Scale, precomputeStride):
tdi = scn.threadDatasetIterator(bd) tdi = scn.threadDatasetIterator(bd)
def iter(): def iter():
randperm = torch.randperm(len(d)) randperm.copy_(torch.randperm(len(d)))
return tdi() return tdi()
return iter return iter
......
...@@ -74,7 +74,7 @@ def train(spatial_size, Scale, precomputeStride): ...@@ -74,7 +74,7 @@ def train(spatial_size, Scale, precomputeStride):
tdi = scn.threadDatasetIterator(bd) tdi = scn.threadDatasetIterator(bd)
def iter(): def iter():
randperm = torch.randperm(len(d)) randperm.copy_(torch.randperm(len(d)))
return tdi() return tdi()
return iter return iter
...@@ -109,7 +109,7 @@ def val(spatial_size, Scale, precomputeStride): ...@@ -109,7 +109,7 @@ def val(spatial_size, Scale, precomputeStride):
tdi = scn.threadDatasetIterator(bd) tdi = scn.threadDatasetIterator(bd)
def iter(): def iter():
randperm = torch.randperm(len(d)) randperm.copy_(torch.randperm(len(d)))
return tdi() return tdi()
return iter return iter
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment