Commit 8c706edd authored by Benjamin Thomas Graham's avatar Benjamin Thomas Graham
Browse files

Leaky UNets

parent 55d55a6a
...@@ -80,15 +80,12 @@ class BatchNormReLU(BatchNormalization): ...@@ -80,15 +80,12 @@ class BatchNormReLU(BatchNormalization):
class BatchNormLeakyReLU(BatchNormalization): class BatchNormLeakyReLU(BatchNormalization):
def __init__(self, nPlanes, eps=1e-4, momentum=0.9): def __init__(self, nPlanes, eps=1e-4, momentum=0.9, leakiness=0.333):
BatchNormalization.__init__(self, nPlanes, eps, momentum, True, 0.333) BatchNormalization.__init__(self, nPlanes, eps, momentum, True, leakiness)
def __repr__(self): def __repr__(self):
s = 'BatchLeakyNorm(' + str(self.nPlanes) + ',eps=' + str(self.eps) + \ s = 'BatchNormLeakyReLU(' + str(self.nPlanes) + ',eps=' + str(self.eps) + \
',momentum=' + str(self.momentum) + ',affine=' + str(self.affine) ',momentum=' + str(self.momentum) + ',affine=' + str(self.affine) + ',leakiness='+str(self.leakiness)+')'
if self.leakiness > 0:
s = s + ',leakiness=' + str(self.leakiness)
s = s + ')'
return s return s
class BatchNormalizationFunction(Function): class BatchNormalizationFunction(Function):
......
...@@ -200,7 +200,7 @@ def SparseResNet(dimension, nInputPlanes, layers): ...@@ -200,7 +200,7 @@ def SparseResNet(dimension, nInputPlanes, layers):
return m return m
def UNet(dimension, reps, nPlanes, residual_blocks=False, downsample=[2, 2]): def UNet(dimension, reps, nPlanes, residual_blocks=False, downsample=[2, 2], leakiness=0):
""" """
U-Net style network with VGG or ResNet-style blocks. U-Net style network with VGG or ResNet-style blocks.
For voxel level prediction: For voxel level prediction:
...@@ -223,14 +223,14 @@ def UNet(dimension, reps, nPlanes, residual_blocks=False, downsample=[2, 2]): ...@@ -223,14 +223,14 @@ def UNet(dimension, reps, nPlanes, residual_blocks=False, downsample=[2, 2]):
m.add(scn.ConcatTable() m.add(scn.ConcatTable()
.add(scn.Identity() if a == b else scn.NetworkInNetwork(a, b, False)) .add(scn.Identity() if a == b else scn.NetworkInNetwork(a, b, False))
.add(scn.Sequential() .add(scn.Sequential()
.add(scn.BatchNormReLU(a)) .add(scn.BatchNormLeakyReLU(a,leakiness=leakiness))
.add(scn.SubmanifoldConvolution(dimension, a, b, 3, False)) .add(scn.SubmanifoldConvolution(dimension, a, b, 3, False))
.add(scn.BatchNormReLU(b)) .add(scn.BatchNormLeakyReLU(b,leakiness=leakiness))
.add(scn.SubmanifoldConvolution(dimension, b, b, 3, False))) .add(scn.SubmanifoldConvolution(dimension, b, b, 3, False)))
).add(scn.AddTable()) ).add(scn.AddTable())
else: #VGG style blocks else: #VGG style blocks
m.add(scn.Sequential() m.add(scn.Sequential()
.add(scn.BatchNormReLU(a)) .add(scn.BatchNormLeakyReLU(a,leakiness=leakiness))
.add(scn.SubmanifoldConvolution(dimension, a, b, 3, False))) .add(scn.SubmanifoldConvolution(dimension, a, b, 3, False)))
def U(nPlanes): #Recursive function def U(nPlanes): #Recursive function
m = scn.Sequential() m = scn.Sequential()
...@@ -245,11 +245,11 @@ def UNet(dimension, reps, nPlanes, residual_blocks=False, downsample=[2, 2]): ...@@ -245,11 +245,11 @@ def UNet(dimension, reps, nPlanes, residual_blocks=False, downsample=[2, 2]):
scn.ConcatTable().add( scn.ConcatTable().add(
scn.Identity()).add( scn.Identity()).add(
scn.Sequential().add( scn.Sequential().add(
scn.BatchNormReLU(nPlanes[0])).add( scn.BatchNormLeakyReLU(nPlanes[0],leakiness=leakiness)).add(
scn.Convolution(dimension, nPlanes[0], nPlanes[1], scn.Convolution(dimension, nPlanes[0], nPlanes[1],
downsample[0], downsample[1], False)).add( downsample[0], downsample[1], False)).add(
U(nPlanes[1:])).add( U(nPlanes[1:])).add(
scn.BatchNormReLU(nPlanes[1])).add( scn.BatchNormLeakyReLU(nPlanes[1],leakiness=leakiness)).add(
scn.Deconvolution(dimension, nPlanes[1], nPlanes[0], scn.Deconvolution(dimension, nPlanes[1], nPlanes[0],
downsample[0], downsample[1], False)))) downsample[0], downsample[1], False))))
m.add(scn.JoinTable()) m.add(scn.JoinTable())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment