"magic_pdf/git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "b2790f6f45d8d4530cc82db7ac8e1537bcec5203"
Commit cb4a4853 authored by Andrew Hundt's avatar Andrew Hundt
Browse files

resnet_v2.py segmentation bugs from 7e2435e5 resolved

parent 7818c255
...@@ -211,6 +211,8 @@ def resnet_v2(inputs, ...@@ -211,6 +211,8 @@ def resnet_v2(inputs,
normalizer_fn=None, scope='logits') normalizer_fn=None, scope='logits')
if spatial_squeeze: if spatial_squeeze:
logits = tf.squeeze(net, [1, 2], name='SpatialSqueeze') logits = tf.squeeze(net, [1, 2], name='SpatialSqueeze')
else:
logits = net
# Convert end_points_collection into a dictionary of end_points. # Convert end_points_collection into a dictionary of end_points.
end_points = slim.utils.convert_collection_to_dict(end_points_collection) end_points = slim.utils.convert_collection_to_dict(end_points_collection)
if num_classes is not None: if num_classes is not None:
...@@ -224,6 +226,7 @@ def resnet_v2_50(inputs, ...@@ -224,6 +226,7 @@ def resnet_v2_50(inputs,
is_training=True, is_training=True,
global_pool=True, global_pool=True,
output_stride=None, output_stride=None,
spatial_squeeze=True,
reuse=None, reuse=None,
scope='resnet_v2_50'): scope='resnet_v2_50'):
"""ResNet-50 model of [1]. See resnet_v2() for arg and return description.""" """ResNet-50 model of [1]. See resnet_v2() for arg and return description."""
...@@ -238,7 +241,8 @@ def resnet_v2_50(inputs, ...@@ -238,7 +241,8 @@ def resnet_v2_50(inputs,
'block4', bottleneck, [(2048, 512, 1)] * 3)] 'block4', bottleneck, [(2048, 512, 1)] * 3)]
return resnet_v2(inputs, blocks, num_classes, is_training=is_training, return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
global_pool=global_pool, output_stride=output_stride, global_pool=global_pool, output_stride=output_stride,
include_root_block=True, reuse=reuse, scope=scope) include_root_block=True, spatial_squeeze=spatial_squeeze,
reuse=reuse, scope=scope)
resnet_v2_50.default_image_size = resnet_v2.default_image_size resnet_v2_50.default_image_size = resnet_v2.default_image_size
...@@ -247,6 +251,7 @@ def resnet_v2_101(inputs, ...@@ -247,6 +251,7 @@ def resnet_v2_101(inputs,
is_training=True, is_training=True,
global_pool=True, global_pool=True,
output_stride=None, output_stride=None,
spatial_squeeze=True,
reuse=None, reuse=None,
scope='resnet_v2_101'): scope='resnet_v2_101'):
"""ResNet-101 model of [1]. See resnet_v2() for arg and return description.""" """ResNet-101 model of [1]. See resnet_v2() for arg and return description."""
...@@ -261,7 +266,8 @@ def resnet_v2_101(inputs, ...@@ -261,7 +266,8 @@ def resnet_v2_101(inputs,
'block4', bottleneck, [(2048, 512, 1)] * 3)] 'block4', bottleneck, [(2048, 512, 1)] * 3)]
return resnet_v2(inputs, blocks, num_classes, is_training=is_training, return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
global_pool=global_pool, output_stride=output_stride, global_pool=global_pool, output_stride=output_stride,
include_root_block=True, reuse=reuse, scope=scope) include_root_block=True, spatial_squeeze=spatial_squeeze,
reuse=reuse, scope=scope)
resnet_v2_101.default_image_size = resnet_v2.default_image_size resnet_v2_101.default_image_size = resnet_v2.default_image_size
...@@ -270,6 +276,7 @@ def resnet_v2_152(inputs, ...@@ -270,6 +276,7 @@ def resnet_v2_152(inputs,
is_training=True, is_training=True,
global_pool=True, global_pool=True,
output_stride=None, output_stride=None,
spatial_squeeze=True,
reuse=None, reuse=None,
scope='resnet_v2_152'): scope='resnet_v2_152'):
"""ResNet-152 model of [1]. See resnet_v2() for arg and return description.""" """ResNet-152 model of [1]. See resnet_v2() for arg and return description."""
...@@ -284,7 +291,8 @@ def resnet_v2_152(inputs, ...@@ -284,7 +291,8 @@ def resnet_v2_152(inputs,
'block4', bottleneck, [(2048, 512, 1)] * 3)] 'block4', bottleneck, [(2048, 512, 1)] * 3)]
return resnet_v2(inputs, blocks, num_classes, is_training=is_training, return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
global_pool=global_pool, output_stride=output_stride, global_pool=global_pool, output_stride=output_stride,
include_root_block=True, reuse=reuse, scope=scope) include_root_block=True, spatial_squeeze=spatial_squeeze,
reuse=reuse, scope=scope)
resnet_v2_152.default_image_size = resnet_v2.default_image_size resnet_v2_152.default_image_size = resnet_v2.default_image_size
...@@ -293,6 +301,7 @@ def resnet_v2_200(inputs, ...@@ -293,6 +301,7 @@ def resnet_v2_200(inputs,
is_training=True, is_training=True,
global_pool=True, global_pool=True,
output_stride=None, output_stride=None,
spatial_squeeze=True,
reuse=None, reuse=None,
scope='resnet_v2_200'): scope='resnet_v2_200'):
"""ResNet-200 model of [2]. See resnet_v2() for arg and return description.""" """ResNet-200 model of [2]. See resnet_v2() for arg and return description."""
...@@ -307,5 +316,6 @@ def resnet_v2_200(inputs, ...@@ -307,5 +316,6 @@ def resnet_v2_200(inputs,
'block4', bottleneck, [(2048, 512, 1)] * 3)] 'block4', bottleneck, [(2048, 512, 1)] * 3)]
return resnet_v2(inputs, blocks, num_classes, is_training=is_training, return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
global_pool=global_pool, output_stride=output_stride, global_pool=global_pool, output_stride=output_stride,
include_root_block=True, reuse=reuse, scope=scope) include_root_block=True, spatial_squeeze=spatial_squeeze,
reuse=reuse, scope=scope)
resnet_v2_200.default_image_size = resnet_v2.default_image_size resnet_v2_200.default_image_size = resnet_v2.default_image_size
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment