Commit 7818c255 authored by Andrew Hundt's avatar Andrew Hundt
Browse files

resnet_v1 segmentation bugs from 7e2435e5 resolved

parent 0d961be2
...@@ -202,6 +202,8 @@ def resnet_v1(inputs, ...@@ -202,6 +202,8 @@ def resnet_v1(inputs,
normalizer_fn=None, scope='logits') normalizer_fn=None, scope='logits')
if spatial_squeeze: if spatial_squeeze:
logits = tf.squeeze(net, [1, 2], name='SpatialSqueeze') logits = tf.squeeze(net, [1, 2], name='SpatialSqueeze')
else:
logits = net
# Convert end_points_collection into a dictionary of end_points. # Convert end_points_collection into a dictionary of end_points.
end_points = slim.utils.convert_collection_to_dict(end_points_collection) end_points = slim.utils.convert_collection_to_dict(end_points_collection)
if num_classes is not None: if num_classes is not None:
...@@ -215,6 +217,7 @@ def resnet_v1_50(inputs, ...@@ -215,6 +217,7 @@ def resnet_v1_50(inputs,
is_training=True, is_training=True,
global_pool=True, global_pool=True,
output_stride=None, output_stride=None,
spatial_squeeze=True,
reuse=None, reuse=None,
scope='resnet_v1_50'): scope='resnet_v1_50'):
"""ResNet-50 model of [1]. See resnet_v1() for arg and return description.""" """ResNet-50 model of [1]. See resnet_v1() for arg and return description."""
...@@ -230,7 +233,8 @@ def resnet_v1_50(inputs, ...@@ -230,7 +233,8 @@ def resnet_v1_50(inputs,
] ]
return resnet_v1(inputs, blocks, num_classes, is_training, return resnet_v1(inputs, blocks, num_classes, is_training,
global_pool=global_pool, output_stride=output_stride, global_pool=global_pool, output_stride=output_stride,
include_root_block=True, reuse=reuse, scope=scope) include_root_block=True, spatial_squeeze=spatial_squeeze,
reuse=reuse, scope=scope)
resnet_v1_50.default_image_size = resnet_v1.default_image_size resnet_v1_50.default_image_size = resnet_v1.default_image_size
...@@ -239,6 +243,7 @@ def resnet_v1_101(inputs, ...@@ -239,6 +243,7 @@ def resnet_v1_101(inputs,
is_training=True, is_training=True,
global_pool=True, global_pool=True,
output_stride=None, output_stride=None,
spatial_squeeze=True,
reuse=None, reuse=None,
scope='resnet_v1_101'): scope='resnet_v1_101'):
"""ResNet-101 model of [1]. See resnet_v1() for arg and return description.""" """ResNet-101 model of [1]. See resnet_v1() for arg and return description."""
...@@ -254,7 +259,8 @@ def resnet_v1_101(inputs, ...@@ -254,7 +259,8 @@ def resnet_v1_101(inputs,
] ]
return resnet_v1(inputs, blocks, num_classes, is_training, return resnet_v1(inputs, blocks, num_classes, is_training,
global_pool=global_pool, output_stride=output_stride, global_pool=global_pool, output_stride=output_stride,
include_root_block=True, reuse=reuse, scope=scope) include_root_block=True, spatial_squeeze=spatial_squeeze,
reuse=reuse, scope=scope)
resnet_v1_101.default_image_size = resnet_v1.default_image_size resnet_v1_101.default_image_size = resnet_v1.default_image_size
...@@ -263,6 +269,7 @@ def resnet_v1_152(inputs, ...@@ -263,6 +269,7 @@ def resnet_v1_152(inputs,
is_training=True, is_training=True,
global_pool=True, global_pool=True,
output_stride=None, output_stride=None,
spatial_squeeze=True,
reuse=None, reuse=None,
scope='resnet_v1_152'): scope='resnet_v1_152'):
"""ResNet-152 model of [1]. See resnet_v1() for arg and return description.""" """ResNet-152 model of [1]. See resnet_v1() for arg and return description."""
...@@ -277,7 +284,8 @@ def resnet_v1_152(inputs, ...@@ -277,7 +284,8 @@ def resnet_v1_152(inputs,
'block4', bottleneck, [(2048, 512, 1)] * 3)] 'block4', bottleneck, [(2048, 512, 1)] * 3)]
return resnet_v1(inputs, blocks, num_classes, is_training, return resnet_v1(inputs, blocks, num_classes, is_training,
global_pool=global_pool, output_stride=output_stride, global_pool=global_pool, output_stride=output_stride,
include_root_block=True, reuse=reuse, scope=scope) include_root_block=True, spatial_squeeze=spatial_squeeze,
reuse=reuse, scope=scope)
resnet_v1_152.default_image_size = resnet_v1.default_image_size resnet_v1_152.default_image_size = resnet_v1.default_image_size
...@@ -286,6 +294,7 @@ def resnet_v1_200(inputs, ...@@ -286,6 +294,7 @@ def resnet_v1_200(inputs,
is_training=True, is_training=True,
global_pool=True, global_pool=True,
output_stride=None, output_stride=None,
spatial_squeeze=True,
reuse=None, reuse=None,
scope='resnet_v1_200'): scope='resnet_v1_200'):
"""ResNet-200 model of [2]. See resnet_v1() for arg and return description.""" """ResNet-200 model of [2]. See resnet_v1() for arg and return description."""
...@@ -300,5 +309,6 @@ def resnet_v1_200(inputs, ...@@ -300,5 +309,6 @@ def resnet_v1_200(inputs,
'block4', bottleneck, [(2048, 512, 1)] * 3)] 'block4', bottleneck, [(2048, 512, 1)] * 3)]
return resnet_v1(inputs, blocks, num_classes, is_training, return resnet_v1(inputs, blocks, num_classes, is_training,
global_pool=global_pool, output_stride=output_stride, global_pool=global_pool, output_stride=output_stride,
include_root_block=True, reuse=reuse, scope=scope) include_root_block=True, spatial_squeeze=spatial_squeeze,
reuse=reuse, scope=scope)
resnet_v1_200.default_image_size = resnet_v1.default_image_size resnet_v1_200.default_image_size = resnet_v1.default_image_size
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment