Commit 8c706edd authored by Benjamin Thomas Graham's avatar Benjamin Thomas Graham
Browse files

Leaky UNets

parent 55d55a6a
......@@ -80,15 +80,12 @@ class BatchNormReLU(BatchNormalization):
class BatchNormLeakyReLU(BatchNormalization):
def __init__(self, nPlanes, eps=1e-4, momentum=0.9):
BatchNormalization.__init__(self, nPlanes, eps, momentum, True, 0.333)
def __init__(self, nPlanes, eps=1e-4, momentum=0.9, leakiness=0.333):
BatchNormalization.__init__(self, nPlanes, eps, momentum, True, leakiness)
def __repr__(self):
s = 'BatchLeakyNorm(' + str(self.nPlanes) + ',eps=' + str(self.eps) + \
',momentum=' + str(self.momentum) + ',affine=' + str(self.affine)
if self.leakiness > 0:
s = s + ',leakiness=' + str(self.leakiness)
s = s + ')'
s = 'BatchNormLeakyReLU(' + str(self.nPlanes) + ',eps=' + str(self.eps) + \
',momentum=' + str(self.momentum) + ',affine=' + str(self.affine) + ',leakiness='+str(self.leakiness)+')'
return s
class BatchNormalizationFunction(Function):
......
......@@ -200,7 +200,7 @@ def SparseResNet(dimension, nInputPlanes, layers):
return m
def UNet(dimension, reps, nPlanes, residual_blocks=False, downsample=[2, 2]):
def UNet(dimension, reps, nPlanes, residual_blocks=False, downsample=[2, 2], leakiness=0):
"""
U-Net style network with VGG or ResNet-style blocks.
For voxel level prediction:
......@@ -223,14 +223,14 @@ def UNet(dimension, reps, nPlanes, residual_blocks=False, downsample=[2, 2]):
m.add(scn.ConcatTable()
.add(scn.Identity() if a == b else scn.NetworkInNetwork(a, b, False))
.add(scn.Sequential()
.add(scn.BatchNormReLU(a))
.add(scn.BatchNormLeakyReLU(a,leakiness=leakiness))
.add(scn.SubmanifoldConvolution(dimension, a, b, 3, False))
.add(scn.BatchNormReLU(b))
.add(scn.BatchNormLeakyReLU(b,leakiness=leakiness))
.add(scn.SubmanifoldConvolution(dimension, b, b, 3, False)))
).add(scn.AddTable())
else: #VGG style blocks
m.add(scn.Sequential()
.add(scn.BatchNormReLU(a))
.add(scn.BatchNormLeakyReLU(a,leakiness=leakiness))
.add(scn.SubmanifoldConvolution(dimension, a, b, 3, False)))
def U(nPlanes): #Recursive function
m = scn.Sequential()
......@@ -245,11 +245,11 @@ def UNet(dimension, reps, nPlanes, residual_blocks=False, downsample=[2, 2]):
scn.ConcatTable().add(
scn.Identity()).add(
scn.Sequential().add(
scn.BatchNormReLU(nPlanes[0])).add(
scn.BatchNormLeakyReLU(nPlanes[0],leakiness=leakiness)).add(
scn.Convolution(dimension, nPlanes[0], nPlanes[1],
downsample[0], downsample[1], False)).add(
U(nPlanes[1:])).add(
scn.BatchNormReLU(nPlanes[1])).add(
scn.BatchNormLeakyReLU(nPlanes[1],leakiness=leakiness)).add(
scn.Deconvolution(dimension, nPlanes[1], nPlanes[0],
downsample[0], downsample[1], False))))
m.add(scn.JoinTable())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment