Commit 9c17823e authored by derekjchow's avatar derekjchow Committed by Sergio Guadarrama
Browse files

Add comment clarifying spatial squeeze. (#1613)

parent b4968012
...@@ -161,6 +161,9 @@ def resnet_v1(inputs, ...@@ -161,6 +161,9 @@ def resnet_v1(inputs,
max-pooling, if False excludes it. max-pooling, if False excludes it.
spatial_squeeze: if True, logits is of shape [B, C], if false logits is spatial_squeeze: if True, logits is of shape [B, C], if false logits is
of shape [B, 1, 1, C], where B is batch_size and C is number of classes. of shape [B, 1, 1, C], where B is batch_size and C is number of classes.
To use this parameter, the input images must be smaller than 300x300
pixels, in which case the output logit layer does not contain spatial
information and can be removed.
reuse: whether or not the network and its variables should be reused. To be reuse: whether or not the network and its variables should be reused. To be
able to reuse 'scope' must be given. able to reuse 'scope' must be given.
scope: Optional variable_scope. scope: Optional variable_scope.
...@@ -200,16 +203,14 @@ def resnet_v1(inputs, ...@@ -200,16 +203,14 @@ def resnet_v1(inputs,
if num_classes is not None: if num_classes is not None:
net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None, net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
normalizer_fn=None, scope='logits') normalizer_fn=None, scope='logits')
if spatial_squeeze: if spatial_squeeze:
logits = tf.squeeze(net, [1, 2], name='SpatialSqueeze') net = tf.squeeze(net, [1, 2], name='SpatialSqueeze')
else:
logits = net
# Convert end_points_collection into a dictionary of end_points. # Convert end_points_collection into a dictionary of end_points.
end_points = slim.utils.convert_collection_to_dict( end_points = slim.utils.convert_collection_to_dict(
end_points_collection) end_points_collection)
if num_classes is not None: if num_classes is not None:
end_points['predictions'] = slim.softmax(logits, scope='predictions') end_points['predictions'] = slim.softmax(net, scope='predictions')
return logits, end_points return net, end_points
resnet_v1.default_image_size = 224 resnet_v1.default_image_size = 224
......
...@@ -158,6 +158,9 @@ def resnet_v2(inputs, ...@@ -158,6 +158,9 @@ def resnet_v2(inputs,
results of an activation-less convolution. results of an activation-less convolution.
spatial_squeeze: if True, logits is of shape [B, C], if false logits is spatial_squeeze: if True, logits is of shape [B, C], if false logits is
of shape [B, 1, 1, C], where B is batch_size and C is number of classes. of shape [B, 1, 1, C], where B is batch_size and C is number of classes.
To use this parameter, the input images must be smaller than 300x300
pixels, in which case the output logit layer does not contain spatial
information and can be removed.
reuse: whether or not the network and its variables should be reused. To be reuse: whether or not the network and its variables should be reused. To be
able to reuse 'scope' must be given. able to reuse 'scope' must be given.
scope: Optional variable_scope. scope: Optional variable_scope.
...@@ -207,16 +210,14 @@ def resnet_v2(inputs, ...@@ -207,16 +210,14 @@ def resnet_v2(inputs,
if num_classes is not None: if num_classes is not None:
net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None, net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
normalizer_fn=None, scope='logits') normalizer_fn=None, scope='logits')
if spatial_squeeze: if spatial_squeeze:
logits = tf.squeeze(net, [1, 2], name='SpatialSqueeze') net = tf.squeeze(net, [1, 2], name='SpatialSqueeze')
else:
logits = net
# Convert end_points_collection into a dictionary of end_points. # Convert end_points_collection into a dictionary of end_points.
end_points = slim.utils.convert_collection_to_dict( end_points = slim.utils.convert_collection_to_dict(
end_points_collection) end_points_collection)
if num_classes is not None: if num_classes is not None:
end_points['predictions'] = slim.softmax(logits, scope='predictions') end_points['predictions'] = slim.softmax(net, scope='predictions')
return logits, end_points return net, end_points
resnet_v2.default_image_size = 224 resnet_v2.default_image_size = 224
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment