# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for slim.nets.resnet_v2."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf
from tensorflow.contrib import slim as contrib_slim

from nets import resnet_utils
from nets import resnet_v2

slim = contrib_slim

# Opt this graph-mode test module out of resource variables.
# NOTE(review): presumably required so that the `get_variable` /
# `reuse_variables()` based tests below see legacy (ref) variables; confirm
# before removing.
tf.compat.v1.disable_resource_variables()


def create_test_input(batch_size, height, width, channels):
  """Create test input tensor.

  Args:
    batch_size: The number of images per batch or `None` if unknown.
    height: The height of each image or `None` if unknown.
    width: The width of each image or `None` if unknown.
    channels: The number of channels per image or `None` if unknown.

  Returns:
    Either a float32 placeholder `Tensor` of shape
    [batch_size, height, width, channels] if any of the inputs are `None`, or a
    float32 constant `Tensor` of that shape whose value at spatial position
    (h, w) is `h + w`, repeated for every image in the batch and every channel.
  """
  if None in [batch_size, height, width, channels]:
    return tf.compat.v1.placeholder(tf.float32,
                                    (batch_size, height, width, channels))
  # Mesh grid: grid[h, w] == h + w.
  grid = np.add.outer(np.arange(height), np.arange(width))
  # Broadcast the same grid to every batch element and every channel.
  values = np.tile(grid.reshape([1, height, width, 1]),
                   [batch_size, 1, 1, channels])
  return tf.cast(values, dtype=tf.float32)


class ResnetUtilsTest(tf.test.TestCase):
  """Tests for `resnet_utils` helpers and stacking of ResNet v2 blocks."""

  def testSubsampleThreeByThree(self):
    """Subsampling a 3x3 map by 2 keeps rows and columns 0 and 2."""
    x = tf.reshape(tf.cast(tf.range(9), dtype=tf.float32), [1, 3, 3, 1])
    x = resnet_utils.subsample(x, 2)
    expected = tf.reshape(tf.constant([0, 2, 6, 8]), [1, 2, 2, 1])
    with self.test_session():
      self.assertAllClose(x.eval(), expected.eval())

  def testSubsampleFourByFour(self):
    """Subsampling a 4x4 map by 2 keeps rows and columns 0 and 2."""
    x = tf.reshape(tf.cast(tf.range(16), dtype=tf.float32), [1, 4, 4, 1])
    x = resnet_utils.subsample(x, 2)
    expected = tf.reshape(tf.constant([0, 2, 8, 10]), [1, 2, 2, 1])
    with self.test_session():
      self.assertAllClose(x.eval(), expected.eval())

  def testConv2DSameEven(self):
    """Compares `conv2d_same` with plain strided conv on a 4x4 input.

    `conv2d_same` with stride 2 must equal a stride-1 convolution followed by
    subsampling, while a plain stride-2 `slim.conv2d` yields different values.
    """
    n, n2 = 4, 2

    # Input image.
    x = create_test_input(1, n, n, 1)

    # Convolution kernel.
    w = create_test_input(1, 3, 3, 1)
    w = tf.reshape(w, [3, 3, 1, 1])

    # Pre-create the 'Conv' variables with known values, then switch the
    # scope to reuse mode so every conv below shares these weights.
    tf.compat.v1.get_variable('Conv/weights', initializer=w)
    tf.compat.v1.get_variable('Conv/biases', initializer=tf.zeros([1]))
    tf.compat.v1.get_variable_scope().reuse_variables()

    y1 = slim.conv2d(x, 1, [3, 3], stride=1, scope='Conv')
    y1_expected = tf.cast([[14, 28, 43, 26], [28, 48, 66, 37], [43, 66, 84, 46],
                           [26, 37, 46, 22]],
                          dtype=tf.float32)
    y1_expected = tf.reshape(y1_expected, [1, n, n, 1])

    y2 = resnet_utils.subsample(y1, 2)
    y2_expected = tf.cast([[14, 43], [43, 84]], dtype=tf.float32)
    y2_expected = tf.reshape(y2_expected, [1, n2, n2, 1])

    y3 = resnet_utils.conv2d_same(x, 1, 3, stride=2, scope='Conv')
    y3_expected = y2_expected

    y4 = slim.conv2d(x, 1, [3, 3], stride=2, scope='Conv')
    y4_expected = tf.cast([[48, 37], [37, 22]], dtype=tf.float32)
    y4_expected = tf.reshape(y4_expected, [1, n2, n2, 1])

    with self.test_session() as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      self.assertAllClose(y1.eval(), y1_expected.eval())
      self.assertAllClose(y2.eval(), y2_expected.eval())
      self.assertAllClose(y3.eval(), y3_expected.eval())
      self.assertAllClose(y4.eval(), y4_expected.eval())

  def testConv2DSameOdd(self):
    """Compares `conv2d_same` with plain strided conv on a 5x5 input.

    With an odd-sized input, stride-1 conv + subsampling, `conv2d_same` with
    stride 2 and plain stride-2 `slim.conv2d` all produce the same values.
    """
    n, n2 = 5, 3

    # Input image.
    x = create_test_input(1, n, n, 1)

    # Convolution kernel.
    w = create_test_input(1, 3, 3, 1)
    w = tf.reshape(w, [3, 3, 1, 1])

    # Pre-create the 'Conv' variables with known values, then switch the
    # scope to reuse mode so every conv below shares these weights.
    tf.compat.v1.get_variable('Conv/weights', initializer=w)
    tf.compat.v1.get_variable('Conv/biases', initializer=tf.zeros([1]))
    tf.compat.v1.get_variable_scope().reuse_variables()

    y1 = slim.conv2d(x, 1, [3, 3], stride=1, scope='Conv')
    y1_expected = tf.cast(
        [[14, 28, 43, 58, 34], [28, 48, 66, 84, 46], [43, 66, 84, 102, 55],
         [58, 84, 102, 120, 64], [34, 46, 55, 64, 30]],
        dtype=tf.float32)
    y1_expected = tf.reshape(y1_expected, [1, n, n, 1])

    y2 = resnet_utils.subsample(y1, 2)
    y2_expected = tf.cast([[14, 43, 34], [43, 84, 55], [34, 55, 30]],
                          dtype=tf.float32)
    y2_expected = tf.reshape(y2_expected, [1, n2, n2, 1])

    y3 = resnet_utils.conv2d_same(x, 1, 3, stride=2, scope='Conv')
    y3_expected = y2_expected

    y4 = slim.conv2d(x, 1, [3, 3], stride=2, scope='Conv')
    y4_expected = y2_expected

    with self.test_session() as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      self.assertAllClose(y1.eval(), y1_expected.eval())
      self.assertAllClose(y2.eval(), y2_expected.eval())
      self.assertAllClose(y3.eval(), y3_expected.eval())
      self.assertAllClose(y4.eval(), y4_expected.eval())

  def _resnet_plain(self, inputs, blocks, output_stride=None, scope=None):
    """A plain ResNet without extra layers before or after the ResNet blocks.

    Returns:
      A `(net, end_points)` pair, where `end_points` is built from the
      'end_points' collection that every `slim.conv2d` inside writes to.
    """
    with tf.compat.v1.variable_scope(scope, values=[inputs]):
      with slim.arg_scope([slim.conv2d], outputs_collections='end_points'):
        net = resnet_utils.stack_blocks_dense(inputs, blocks, output_stride)
        end_points = slim.utils.convert_collection_to_dict('end_points')
        return net, end_points

  def testEndPointsV2(self):
    """Test the end points of a tiny v2 bottleneck network."""
    blocks = [
        resnet_v2.resnet_v2_block(
            'block1', base_depth=1, num_units=2, stride=2),
        resnet_v2.resnet_v2_block(
            'block2', base_depth=2, num_units=2, stride=1),
    ]
    inputs = create_test_input(2, 32, 16, 3)
    with slim.arg_scope(resnet_utils.resnet_arg_scope()):
      _, end_points = self._resnet_plain(inputs, blocks, scope='tiny')
    expected = [
        'tiny/block1/unit_1/bottleneck_v2/shortcut',
        'tiny/block1/unit_1/bottleneck_v2/conv1',
        'tiny/block1/unit_1/bottleneck_v2/conv2',
        'tiny/block1/unit_1/bottleneck_v2/conv3',
        'tiny/block1/unit_2/bottleneck_v2/conv1',
        'tiny/block1/unit_2/bottleneck_v2/conv2',
        'tiny/block1/unit_2/bottleneck_v2/conv3',
        'tiny/block2/unit_1/bottleneck_v2/shortcut',
        'tiny/block2/unit_1/bottleneck_v2/conv1',
        'tiny/block2/unit_1/bottleneck_v2/conv2',
        'tiny/block2/unit_1/bottleneck_v2/conv3',
        'tiny/block2/unit_2/bottleneck_v2/conv1',
        'tiny/block2/unit_2/bottleneck_v2/conv2',
        'tiny/block2/unit_2/bottleneck_v2/conv3']
    self.assertItemsEqual(expected, end_points.keys())

  def _stack_blocks_nondense(self, net, blocks):
    """A simplified ResNet Block stacker without output stride control.

    Every unit is built with `rate=1`, i.e. at the network's nominal stride.
    """
    for block in blocks:
      with tf.compat.v1.variable_scope(block.scope, 'block', [net]):
        for i, unit in enumerate(block.args):
          with tf.compat.v1.variable_scope('unit_%d' % (i + 1), values=[net]):
            net = block.unit_fn(net, rate=1, **unit)
    return net

  def testAtrousValuesBottleneck(self):
    """Verify the values of dense feature extraction by atrous convolution.

    Make sure that dense feature extraction by stack_blocks_dense() followed by
    subsampling gives identical results to feature extraction at the nominal
    network output stride using the simple self._stack_blocks_nondense() above.
    """
    block = resnet_v2.resnet_v2_block
    blocks = [
        block('block1', base_depth=1, num_units=2, stride=2),
        block('block2', base_depth=2, num_units=2, stride=2),
        block('block3', base_depth=4, num_units=2, stride=2),
        block('block4', base_depth=8, num_units=2, stride=1),
    ]
    # Product of the block strides above: 2 * 2 * 2 * 1.
    nominal_stride = 8

    # Test both odd and even input dimensions.
    height = 30
    width = 31
    with slim.arg_scope(resnet_utils.resnet_arg_scope()):
      with slim.arg_scope([slim.batch_norm], is_training=False):
        for output_stride in [1, 2, 4, 8, None]:
          with tf.Graph().as_default():
            with self.test_session() as sess:
              tf.compat.v1.set_random_seed(0)
              inputs = create_test_input(1, height, width, 3)
              # Dense feature extraction followed by subsampling.
              output = resnet_utils.stack_blocks_dense(inputs,
                                                       blocks,
                                                       output_stride)
              if output_stride is None:
                factor = 1
              else:
                factor = nominal_stride // output_stride

              output = resnet_utils.subsample(output, factor)
              # Make the two networks use the same weights.
              tf.compat.v1.get_variable_scope().reuse_variables()
              # Feature extraction at the nominal network rate.
              expected = self._stack_blocks_nondense(inputs, blocks)
              sess.run(tf.compat.v1.global_variables_initializer())
              output, expected = sess.run([output, expected])
              self.assertAllClose(output, expected, atol=1e-4, rtol=1e-4)


class ResnetCompleteNetworkTest(tf.test.TestCase):
  """Tests with complete small ResNet v2 networks."""

  def _resnet_small(self,
                    inputs,
                    num_classes=None,
                    is_training=True,
                    global_pool=True,
                    output_stride=None,
                    include_root_block=True,
                    spatial_squeeze=True,
                    reuse=None,
                    scope='resnet_v2_small'):
    """A shallow and thin ResNet v2 for faster tests.

    Four bottleneck blocks with 3, 3, 3 and 2 units; all remaining arguments
    are forwarded unchanged to `resnet_v2.resnet_v2`.
    """
    block = resnet_v2.resnet_v2_block
    blocks = [
        block('block1', base_depth=1, num_units=3, stride=2),
        block('block2', base_depth=2, num_units=3, stride=2),
        block('block3', base_depth=4, num_units=3, stride=2),
        block('block4', base_depth=8, num_units=2, stride=1),
    ]
    return resnet_v2.resnet_v2(inputs, blocks, num_classes,
                               is_training=is_training,
                               global_pool=global_pool,
                               output_stride=output_stride,
                               include_root_block=include_root_block,
                               spatial_squeeze=spatial_squeeze,
                               reuse=reuse,
                               scope=scope)

  def testClassificationEndPoints(self):
    """Logits, predictions and global_pool end points have the right shapes."""
    global_pool = True
    num_classes = 10
    inputs = create_test_input(2, 224, 224, 3)
    with slim.arg_scope(resnet_utils.resnet_arg_scope()):
      logits, end_points = self._resnet_small(inputs, num_classes,
                                              global_pool=global_pool,
                                              spatial_squeeze=False,
                                              scope='resnet')
    self.assertTrue(logits.op.name.startswith('resnet/logits'))
    self.assertListEqual(logits.get_shape().as_list(), [2, 1, 1, num_classes])
    self.assertIn('predictions', end_points)
    self.assertListEqual(end_points['predictions'].get_shape().as_list(),
                         [2, 1, 1, num_classes])
    self.assertIn('global_pool', end_points)
    self.assertListEqual(end_points['global_pool'].get_shape().as_list(),
                         [2, 1, 1, 32])

  def testEndpointNames(self):
    """The public API exposes exactly the expected end point names."""
    # Like ResnetUtilsTest.testEndPointsV2(), but for the public API.
    global_pool = True
    num_classes = 10
    inputs = create_test_input(2, 224, 224, 3)
    with slim.arg_scope(resnet_utils.resnet_arg_scope()):
      _, end_points = self._resnet_small(inputs, num_classes,
                                         global_pool=global_pool,
                                         scope='resnet')
    expected = ['resnet/conv1']
    for block in range(1, 5):
      # Blocks 1-3 have 3 units and block 4 has 2 (see _resnet_small).
      for unit in range(1, 4 if block < 4 else 3):
        for conv in range(1, 4):
          expected.append('resnet/block%d/unit_%d/bottleneck_v2/conv%d' %
                          (block, unit, conv))
        expected.append('resnet/block%d/unit_%d/bottleneck_v2' % (block, unit))
      expected.append('resnet/block%d/unit_1/bottleneck_v2/shortcut' % block)
      expected.append('resnet/block%d' % block)
    expected.extend(['global_pool', 'resnet/logits', 'resnet/spatial_squeeze',
                     'predictions'])
    self.assertItemsEqual(end_points.keys(), expected)

  def testClassificationShapes(self):
    """Block end points have the expected shapes for a 224x224 input."""
    global_pool = True
    num_classes = 10
    inputs = create_test_input(2, 224, 224, 3)
    with slim.arg_scope(resnet_utils.resnet_arg_scope()):
      _, end_points = self._resnet_small(inputs, num_classes,
                                         global_pool=global_pool,
                                         scope='resnet')
      endpoint_to_shape = {
          'resnet/block1': [2, 28, 28, 4],
          'resnet/block2': [2, 14, 14, 8],
          'resnet/block3': [2, 7, 7, 16],
          'resnet/block4': [2, 7, 7, 32]}
      for endpoint, shape in endpoint_to_shape.items():
        self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)

  def testFullyConvolutionalEndpointShapes(self):
    """Block end points have the expected shapes for a 321x321 input."""
    global_pool = False
    num_classes = 10
    inputs = create_test_input(2, 321, 321, 3)
    with slim.arg_scope(resnet_utils.resnet_arg_scope()):
      _, end_points = self._resnet_small(inputs, num_classes,
                                         global_pool=global_pool,
                                         spatial_squeeze=False,
                                         scope='resnet')
      endpoint_to_shape = {
          'resnet/block1': [2, 41, 41, 4],
          'resnet/block2': [2, 21, 21, 8],
          'resnet/block3': [2, 11, 11, 16],
          'resnet/block4': [2, 11, 11, 32]}
      for endpoint, shape in endpoint_to_shape.items():
        self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)

  def testRootlessFullyConvolutionalEndpointShapes(self):
    """Block end points have the expected shapes without the root block."""
    global_pool = False
    num_classes = 10
    inputs = create_test_input(2, 128, 128, 3)
    with slim.arg_scope(resnet_utils.resnet_arg_scope()):
      _, end_points = self._resnet_small(inputs, num_classes,
                                         global_pool=global_pool,
                                         include_root_block=False,
                                         spatial_squeeze=False,
                                         scope='resnet')
      endpoint_to_shape = {
          'resnet/block1': [2, 64, 64, 4],
          'resnet/block2': [2, 32, 32, 8],
          'resnet/block3': [2, 16, 16, 16],
          'resnet/block4': [2, 16, 16, 32]}
      for endpoint, shape in endpoint_to_shape.items():
        self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)

  def testAtrousFullyConvolutionalEndpointShapes(self):
    """With output_stride=8, all blocks keep the block1 spatial resolution."""
    global_pool = False
    num_classes = 10
    output_stride = 8
    inputs = create_test_input(2, 321, 321, 3)
    with slim.arg_scope(resnet_utils.resnet_arg_scope()):
      _, end_points = self._resnet_small(inputs,
                                         num_classes,
                                         global_pool=global_pool,
                                         output_stride=output_stride,
                                         spatial_squeeze=False,
                                         scope='resnet')
      endpoint_to_shape = {
          'resnet/block1': [2, 41, 41, 4],
          'resnet/block2': [2, 41, 41, 8],
          'resnet/block3': [2, 41, 41, 16],
          'resnet/block4': [2, 41, 41, 32]}
      for endpoint, shape in endpoint_to_shape.items():
        self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)

  def testAtrousFullyConvolutionalValues(self):
    """Verify dense feature extraction with atrous convolution."""
    # Overall stride of _resnet_small with its root block and default rates.
    nominal_stride = 32
    for output_stride in [4, 8, 16, 32, None]:
      with slim.arg_scope(resnet_utils.resnet_arg_scope()):
        with tf.Graph().as_default():
          with self.test_session() as sess:
            tf.compat.v1.set_random_seed(0)
            inputs = create_test_input(2, 81, 81, 3)
            # Dense feature extraction followed by subsampling.
            output, _ = self._resnet_small(inputs, None,
                                           is_training=False,
                                           global_pool=False,
                                           output_stride=output_stride)
            if output_stride is None:
              factor = 1
            else:
              factor = nominal_stride // output_stride
            output = resnet_utils.subsample(output, factor)
            # Make the two networks use the same weights.
            tf.compat.v1.get_variable_scope().reuse_variables()
            # Feature extraction at the nominal network rate.
            expected, _ = self._resnet_small(inputs, None,
                                             is_training=False,
                                             global_pool=False)
            sess.run(tf.compat.v1.global_variables_initializer())
            self.assertAllClose(output.eval(), expected.eval(),
                                atol=1e-4, rtol=1e-4)

  def testUnknownBatchSize(self):
    """The network builds with an unknown batch size and runs on a batch of 2."""
    batch = 2
    height, width = 65, 65
    global_pool = True
    num_classes = 10
    inputs = create_test_input(None, height, width, 3)
    with slim.arg_scope(resnet_utils.resnet_arg_scope()):
      logits, _ = self._resnet_small(inputs, num_classes,
                                     global_pool=global_pool,
                                     spatial_squeeze=False,
                                     scope='resnet')
    self.assertTrue(logits.op.name.startswith('resnet/logits'))
    self.assertListEqual(logits.get_shape().as_list(),
                         [None, 1, 1, num_classes])
    images = create_test_input(batch, height, width, 3)
    with self.test_session() as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      output = sess.run(logits, {inputs: images.eval()})
      self.assertEqual(output.shape, (batch, 1, 1, num_classes))

  def testFullyConvolutionalUnknownHeightWidth(self):
    """The network builds with unknown spatial size and runs on 65x65 input."""
    batch = 2
    height, width = 65, 65
    global_pool = False
    inputs = create_test_input(batch, None, None, 3)
    with slim.arg_scope(resnet_utils.resnet_arg_scope()):
      output, _ = self._resnet_small(inputs, None,
                                     global_pool=global_pool)
    self.assertListEqual(output.get_shape().as_list(),
                         [batch, None, None, 32])
    images = create_test_input(batch, height, width, 3)
    with self.test_session() as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      output = sess.run(output, {inputs: images.eval()})
      self.assertEqual(output.shape, (batch, 3, 3, 32))

  def testAtrousFullyConvolutionalUnknownHeightWidth(self):
    """Like the test above, but with output_stride=8 (9x9 output)."""
    batch = 2
    height, width = 65, 65
    global_pool = False
    output_stride = 8
    inputs = create_test_input(batch, None, None, 3)
    with slim.arg_scope(resnet_utils.resnet_arg_scope()):
      output, _ = self._resnet_small(inputs,
                                     None,
                                     global_pool=global_pool,
                                     output_stride=output_stride)
    self.assertListEqual(output.get_shape().as_list(),
                         [batch, None, None, 32])
    images = create_test_input(batch, height, width, 3)
    with self.test_session() as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      output = sess.run(output, {inputs: images.eval()})
      self.assertEqual(output.shape, (batch, 9, 9, 32))


if __name__ == '__main__':
  # Hand control to TensorFlow's test runner when this file is run directly.
  tf.test.main()