Commit d8b64558 authored by Benjamin Thomas Graham's avatar Benjamin Thomas Graham
Browse files

UNet; fix example data loader permutation

parent 00891eb5
......@@ -17,6 +17,7 @@ from .sparseToDense import SparseToDense
from .denseToSparse import DenseToSparse
from .tables import *
def SparseVggNet(dimension, nInputPlanes, layers):
"""
VGG style nets
......@@ -36,7 +37,7 @@ def SparseVggNet(dimension, nInputPlanes, layers):
m.add(BatchNormReLU(nPlanes))
elif x[0] == 'C' and len(x) == 3:
m.add(ConcatTable()
.add(
.add(
SubmanifoldConvolution(dimension, nPlanes, x[1], 3, False)
).add(
Sequential()
......@@ -105,7 +106,7 @@ def SparseVggNet(dimension, nInputPlanes, layers):
.add(SubmanifoldConvolution(dimension, x[3], x[3], 3, False))
.add(BatchNormReLU(x[3]))
.add(Deconvolution(dimension, x[3], x[3], 3, 2, False))
)
)
.add(Sequential()
.add(Convolution(dimension, nPlanes, x[4], 3, 2, False))
.add(BatchNormReLU(x[4]))
......@@ -133,6 +134,7 @@ def SparseVggNet(dimension, nInputPlanes, layers):
m.add(BatchNormReLU(nPlanes))
return m
def SparseResNet(dimension, nInputPlanes, layers):
"""
pre-activated ResNet
......@@ -202,3 +204,72 @@ def SparseResNet(dimension, nInputPlanes, layers):
m.add(AddTable())
m.add(BatchNormReLU(nPlanes))
return m
def ResNetUNet(dimension, nPlanes, reps, depth=4):
"""
U-Net style network with ResNet-style blocks.
For voxel level prediction:
import sparseconvnet as scn
import torch.nn
class Model(nn.Module):
def __init__(self):
nn.Module.__init__(self)
self.sparseModel = scn.Sequential().add(
scn.ValidConvolution(3, nInputFeatures, 64, 3, False)).add(
scn.ResNetUNet(3, 64, 2, 4))
self.linear = nn.Linear(64, nClasses)
def forward(self,x):
x=self.sparseModel(x).features
x=self.linear(x)
return x
"""
def res(m, a, b):
m.add(ConcatTable()
.add(Identity() if a == b else NetworkInNetwork(a, b, False))
.add(Sequential()
.add(BatchNormReLU(a))
.add(SubmanifoldConvolution(dimension, a, b, 3, False))
.add(BatchNormReLU(b))
.add(SubmanifoldConvolution(dimension, b, b, 3, False))))\
.add(AddTable())
def v(depth, nPlanes):
m = Sequential()
if depth == 1:
for _ in range(reps):
res(m, nPlanes, nPlanes)
else:
m = Sequential()
for _ in range(reps):
res(m, nPlanes, nPlanes)
m.add(
ConcatTable() .add(
Identity()) .add(
Sequential() .add(
BatchNormReLU(nPlanes)) .add(
Convolution(
dimension,
nPlanes,
nPlanes,
2,
2,
False)) .add(
v(
depth - 1,
nPlanes)) .add(
BatchNormReLU(nPlanes)) .add(
Deconvolution(
dimension,
nPlanes,
nPlanes,
2,
2,
False))))
m.add(JoinTable())
for i in range(reps):
res(m, 2 * nPlanes if i == 0 else nPlanes, nPlanes)
return m
m = v(depth, nPlanes)
m.add(BatchNormReLU(nPlanes))
return m
......@@ -85,7 +85,7 @@ def train(spatial_size, Scale, precomputeStride):
tdi = scn.threadDatasetIterator(bd)
def iter():
randperm = torch.randperm(len(d))
randperm.copy_(torch.randperm(len(d)))
return tdi()
return iter
......@@ -120,7 +120,7 @@ def val(spatial_size, Scale, precomputeStride):
tdi = scn.threadDatasetIterator(bd)
def iter():
randperm = torch.randperm(len(d))
randperm.copy_(torch.randperm(len(d)))
return tdi()
return iter
......
......@@ -74,7 +74,7 @@ def train(spatial_size, Scale, precomputeStride):
tdi = scn.threadDatasetIterator(bd)
def iter():
randperm = torch.randperm(len(d))
randperm.copy_(torch.randperm(len(d)))
return tdi()
return iter
......@@ -109,7 +109,7 @@ def val(spatial_size, Scale, precomputeStride):
tdi = scn.threadDatasetIterator(bd)
def iter():
randperm = torch.randperm(len(d))
randperm.copy_(torch.randperm(len(d)))
return tdi()
return iter
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment