Commit 9dd2c618 authored by Neal Wu's avatar Neal Wu
Browse files

Perform the squeeze in a more appropriate location

parent a9d0e6e8
...@@ -206,12 +206,12 @@ def resnet_v2(inputs, ...@@ -206,12 +206,12 @@ def resnet_v2(inputs,
if num_classes is not None: if num_classes is not None:
net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None, net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
normalizer_fn=None, scope='logits') normalizer_fn=None, scope='logits')
logits = tf.squeeze(net, [1, 2], name='SpatialSqueeze')
# Convert end_points_collection into a dictionary of end_points. # Convert end_points_collection into a dictionary of end_points.
end_points = slim.utils.convert_collection_to_dict(end_points_collection) end_points = slim.utils.convert_collection_to_dict(end_points_collection)
if num_classes is not None: if num_classes is not None:
end_points['predictions'] = slim.softmax(net, scope='predictions') end_points['predictions'] = slim.softmax(net, scope='predictions')
return net, end_points return logits, end_points
resnet_v2.default_image_size = 224
def resnet_v2_50(inputs, def resnet_v2_50(inputs,
......
...@@ -473,7 +473,7 @@ def main(_): ...@@ -473,7 +473,7 @@ def main(_):
end_points['AuxLogits'], labels, end_points['AuxLogits'], labels,
label_smoothing=FLAGS.label_smoothing, weights=0.4, scope='aux_loss') label_smoothing=FLAGS.label_smoothing, weights=0.4, scope='aux_loss')
tf.losses.softmax_cross_entropy( tf.losses.softmax_cross_entropy(
tf.squeeze(logits), labels, label_smoothing=FLAGS.label_smoothing, weights=1.0) logits, labels, label_smoothing=FLAGS.label_smoothing, weights=1.0)
return end_points return end_points
# Gather initial summaries. # Gather initial summaries.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment