Unverified commit d5663b3c authored by Karmel Allison, committed by GitHub

Use reduce_mean instead of Average Pooling in Resnet (#3675)

* Try out affine layer instead of dense

* Use reduce mean instead of avg pooling

* Remove np

* Use reduce mean instead of avg pooling

* Fix axes

* Cleanup

* Fixing comment

* Fixing tests
parent 5ef68f57
@@ -64,7 +64,7 @@ class BaseTest(tf.test.TestCase):
     block_layer2 = graph.get_tensor_by_name('block_layer2:0')
     block_layer3 = graph.get_tensor_by_name('block_layer3:0')
     block_layer4 = graph.get_tensor_by_name('block_layer4:0')
-    avg_pool = graph.get_tensor_by_name('final_avg_pool:0')
+    reduce_mean = graph.get_tensor_by_name('final_reduce_mean:0')
     dense = graph.get_tensor_by_name('final_dense:0')

     self.assertAllEqual(initial_conv.shape, reshape((1, 64, 112, 112)))
@@ -77,13 +77,13 @@ class BaseTest(tf.test.TestCase):
       self.assertAllEqual(block_layer2.shape, reshape((1, 128, 28, 28)))
       self.assertAllEqual(block_layer3.shape, reshape((1, 256, 14, 14)))
       self.assertAllEqual(block_layer4.shape, reshape((1, 512, 7, 7)))
-      self.assertAllEqual(avg_pool.shape, reshape((1, 512, 1, 1)))
+      self.assertAllEqual(reduce_mean.shape, reshape((1, 512, 1, 1)))
     else:
       self.assertAllEqual(block_layer1.shape, reshape((1, 256, 56, 56)))
       self.assertAllEqual(block_layer2.shape, reshape((1, 512, 28, 28)))
       self.assertAllEqual(block_layer3.shape, reshape((1, 1024, 14, 14)))
       self.assertAllEqual(block_layer4.shape, reshape((1, 2048, 7, 7)))
-      self.assertAllEqual(avg_pool.shape, reshape((1, 2048, 1, 1)))
+      self.assertAllEqual(reduce_mean.shape, reshape((1, 2048, 1, 1)))
     self.assertAllEqual(dense.shape, (1, _LABEL_CLASSES))
     self.assertAllEqual(output.shape, (1, _LABEL_CLASSES))
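The shape the updated test asserts follows from how reduce_mean behaves with keepdims=True: the spatial axes are averaged away but kept as size-1 dimensions, so a channels_first feature map of shape (1, 512, 7, 7) becomes (1, 512, 1, 1). A minimal standalone sketch of that behavior (assuming TF 2.x eager execution; the sizes are illustrative, not taken from the test):

import tensorflow as tf

# Illustrative channels_first feature map: batch x channels x height x width.
x = tf.zeros([1, 512, 7, 7])

# Average over the spatial axes (2 and 3 for channels_first); keepdims=True
# keeps those axes as size-1 dimensions instead of dropping them.
y = tf.reduce_mean(x, axis=[2, 3], keepdims=True)

print(y.shape)  # (1, 512, 1, 1)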
@@ -31,10 +31,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

 import tensorflow as tf

 _BATCH_NORM_DECAY = 0.997
 _BATCH_NORM_EPSILON = 1e-5
 DEFAULT_VERSION = 2
@@ -461,13 +459,18 @@ class Model(object):
     inputs = batch_norm(inputs, training, self.data_format)
     inputs = tf.nn.relu(inputs)

-    inputs = tf.layers.average_pooling2d(
-        inputs=inputs, pool_size=self.second_pool_size,
-        strides=self.second_pool_stride, padding='VALID',
-        data_format=self.data_format)
-    inputs = tf.identity(inputs, 'final_avg_pool')
+    # The current top layer has shape
+    # `batch_size x pool_size x pool_size x final_size`.
+    # ResNet does an Average Pooling layer over pool_size,
+    # but that is the same as doing a reduce_mean. We do a reduce_mean
+    # here because it performs better than AveragePooling2D.
+    axes = [2, 3] if self.data_format == 'channels_first' else [1, 2]
+    inputs = tf.reduce_mean(inputs, axes, keepdims=True)
+    inputs = tf.identity(inputs, 'final_reduce_mean')

     inputs = tf.reshape(inputs, [-1, self.final_size])
     inputs = tf.layers.dense(inputs=inputs, units=self.num_classes)
     inputs = tf.identity(inputs, 'final_dense')
     return inputs
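The comment added in the diff states that average pooling over the full pool_size window computes the same values as a reduce_mean over the spatial axes. A minimal standalone sketch of that equivalence, separate from the commit itself (assuming TF 2.x eager execution, channels_last data, and an illustrative 7x7x512 top layer):

import tensorflow as tf

# Illustrative top-layer activations: batch x height x width x channels.
x = tf.random.normal([1, 7, 7, 512])

# Old path: average pooling with a window spanning the whole 7x7 feature map.
pooled = tf.nn.avg_pool2d(x, ksize=7, strides=7, padding='VALID')

# New path: reduce_mean over the spatial axes (1 and 2 for channels_last).
reduced = tf.reduce_mean(x, axis=[1, 2], keepdims=True)

print(pooled.shape, reduced.shape)  # both (1, 1, 1, 512)
print(tf.reduce_max(tf.abs(pooled - reduced)).numpy())  # ~0, up to float rounding

The commit therefore leaves the output shape and values unchanged and only swaps the op, on the grounds given in the added comment that the reduce_mean performs better than the AveragePooling2D layer.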