Commit e488fe04 authored by Benjamin Thomas Graham's avatar Benjamin Thomas Graham
Browse files

fix

parent 54c58b5f
# Copyright 2016-present, Facebook, Inc. # Copyright 2g016-present, Facebook, Inc.
# All rights reserved. # All rights reserved.
# #
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
...@@ -218,8 +218,6 @@ def UNet(dimension, reps, nPlanes, residual_blocks=False, downsample=[2, 2], lea ...@@ -218,8 +218,6 @@ def UNet(dimension, reps, nPlanes, residual_blocks=False, downsample=[2, 2], lea
x=self.linear(x) x=self.linear(x)
return x return x
""" """
if n_input_planes==-1:
n_input_planes=nPlanes[0]
def block(m, a, b): def block(m, a, b):
if residual_blocks: #ResNet style blocks if residual_blocks: #ResNet style blocks
m.add(scn.ConcatTable() m.add(scn.ConcatTable()
...@@ -234,10 +232,11 @@ def UNet(dimension, reps, nPlanes, residual_blocks=False, downsample=[2, 2], lea ...@@ -234,10 +232,11 @@ def UNet(dimension, reps, nPlanes, residual_blocks=False, downsample=[2, 2], lea
m.add(scn.Sequential() m.add(scn.Sequential()
.add(scn.BatchNormLeakyReLU(a,leakiness=leakiness)) .add(scn.BatchNormLeakyReLU(a,leakiness=leakiness))
.add(scn.SubmanifoldConvolution(dimension, a, b, 3, False))) .add(scn.SubmanifoldConvolution(dimension, a, b, 3, False)))
def U(nPlanes): #Recursive function def U(nPlanes,n_input_planes=-1): #Recursive function
m = scn.Sequential() m = scn.Sequential()
for i in range(reps): for i in range(reps):
block(m, n_input_planes if i==0 else nPlanes[0], nPlanes[0]) block(m, n_input_planes if n_input_planes!=-1 else nPlanes[0], nPlanes[0])
n_input_planes=-1
if len(nPlanes) > 1: if len(nPlanes) > 1:
m.add( m.add(
scn.ConcatTable().add( scn.ConcatTable().add(
...@@ -254,7 +253,7 @@ def UNet(dimension, reps, nPlanes, residual_blocks=False, downsample=[2, 2], lea ...@@ -254,7 +253,7 @@ def UNet(dimension, reps, nPlanes, residual_blocks=False, downsample=[2, 2], lea
for i in range(reps): for i in range(reps):
block(m, nPlanes[0] * (2 if i == 0 else 1), nPlanes[0]) block(m, nPlanes[0] * (2 if i == 0 else 1), nPlanes[0])
return m return m
m = U(nPlanes) m = U(nPlanes,n_input_planes)
return m return m
def FullyConvolutionalNet(dimension, reps, nPlanes, residual_blocks=False, downsample=[2, 2]): def FullyConvolutionalNet(dimension, reps, nPlanes, residual_blocks=False, downsample=[2, 2]):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment