# nasnet_test.py
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for slim.nasnet."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow.contrib import slim as contrib_slim

from nets.nasnet import nasnet

# Short alias used throughout the tests for slim's arg_scope.
slim = contrib_slim


class NASNetTest(tf.test.TestCase):
  """Graph-construction and output-shape tests for the NASNet builders.

  Covers the CIFAR (32x32 input), Mobile (224x224 input) and Large
  (331x331 input) variants built via `nasnet.build_nasnet_*`.
  """

  def _assertLogitsShapes(self, logits, end_points, batch_size, num_classes):
    """Checks that AuxLogits, logits and Predictions are [batch, classes].

    Args:
      logits: logits tensor returned by a `nasnet.build_nasnet_*` function.
      end_points: end-point dict returned alongside `logits`; must contain
        'AuxLogits' and 'Predictions'.
      batch_size: expected static batch dimension.
      num_classes: expected static class dimension.
    """
    expected_shape = [batch_size, num_classes]
    self.assertListEqual(end_points['AuxLogits'].get_shape().as_list(),
                         expected_shape)
    self.assertListEqual(logits.get_shape().as_list(), expected_shape)
    self.assertListEqual(end_points['Predictions'].get_shape().as_list(),
                         expected_shape)

  def _assertPreLogits(self, net, end_points, batch_size, num_features):
    """Checks the output of a model built with `num_classes=None`.

    Without a classification head there must be no 'AuxLogits' or
    'Predictions' end point, and `net` must come from an op whose name starts
    with 'final_layer/Mean' and have shape [batch_size, num_features].

    Args:
      net: first value returned by a `nasnet.build_nasnet_*` function.
      end_points: end-point dict returned alongside `net`.
      batch_size: expected static batch dimension.
      num_features: expected static feature dimension.
    """
    self.assertNotIn('AuxLogits', end_points)
    self.assertNotIn('Predictions', end_points)
    self.assertTrue(net.op.name.startswith('final_layer/Mean'))
    self.assertListEqual(net.get_shape().as_list(), [batch_size, num_features])

  def _assertEndPointShapes(self, end_points, endpoints_shapes):
    """Checks that `end_points` has exactly the expected names and shapes.

    Args:
      end_points: end-point dict returned by a `nasnet.build_nasnet_*`
        function.
      endpoints_shapes: dict mapping every expected end-point name to its
        expected static shape (a list).
    """
    self.assertCountEqual(endpoints_shapes.keys(), end_points.keys())
    for endpoint_name, expected_shape in endpoints_shapes.items():
      tf.compat.v1.logging.info('Endpoint name: {}'.format(endpoint_name))
      self.assertIn(endpoint_name, end_points)
      self.assertListEqual(end_points[endpoint_name].get_shape().as_list(),
                           expected_shape)

  def _assertAuxHeadToggle(self, build_fn, arg_scope_fn, config_fn,
                           height, width, num_classes, batch_size=5):
    """Checks that the 'use_aux_head' hparam controls the AuxLogits end point.

    The model is built twice, each time in a freshly reset default graph:
    once with the auxiliary head enabled and once with it disabled.

    Args:
      build_fn: one of the `nasnet.build_nasnet_*` functions.
      arg_scope_fn: the matching `nasnet.nasnet_*_arg_scope` function.
      config_fn: the matching `nasnet.*_config` function returning hparams.
      height: input image height.
      width: input image width.
      num_classes: number of output classes.
      batch_size: number of images in the random input batch.
    """
    for use_aux_head in (True, False):
      tf.compat.v1.reset_default_graph()
      inputs = tf.random.uniform((batch_size, height, width, 3))
      tf.compat.v1.train.create_global_step()
      config = config_fn()
      config.set_hparam('use_aux_head', int(use_aux_head))
      with slim.arg_scope(arg_scope_fn()):
        _, end_points = build_fn(inputs, num_classes, config=config)
      self.assertEqual('AuxLogits' in end_points, use_aux_head)

  def testBuildLogitsCifarModel(self):
    batch_size = 5
    height, width = 32, 32
    num_classes = 10
    inputs = tf.random.uniform((batch_size, height, width, 3))
    tf.compat.v1.train.create_global_step()
    with slim.arg_scope(nasnet.nasnet_cifar_arg_scope()):
      logits, end_points = nasnet.build_nasnet_cifar(inputs, num_classes)
    self._assertLogitsShapes(logits, end_points, batch_size, num_classes)

  def testBuildLogitsMobileModel(self):
    batch_size = 5
    height, width = 224, 224
    num_classes = 1000
    inputs = tf.random.uniform((batch_size, height, width, 3))
    tf.compat.v1.train.create_global_step()
    with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
      logits, end_points = nasnet.build_nasnet_mobile(inputs, num_classes)
    self._assertLogitsShapes(logits, end_points, batch_size, num_classes)

  def testBuildLogitsLargeModel(self):
    batch_size = 5
    height, width = 331, 331
    num_classes = 1000
    inputs = tf.random.uniform((batch_size, height, width, 3))
    tf.compat.v1.train.create_global_step()
    with slim.arg_scope(nasnet.nasnet_large_arg_scope()):
      logits, end_points = nasnet.build_nasnet_large(inputs, num_classes)
    self._assertLogitsShapes(logits, end_points, batch_size, num_classes)

  def testBuildPreLogitsCifarModel(self):
    batch_size = 5
    height, width = 32, 32
    num_classes = None
    inputs = tf.random.uniform((batch_size, height, width, 3))
    tf.compat.v1.train.create_global_step()
    with slim.arg_scope(nasnet.nasnet_cifar_arg_scope()):
      net, end_points = nasnet.build_nasnet_cifar(inputs, num_classes)
    self._assertPreLogits(net, end_points, batch_size, 768)

  def testBuildPreLogitsMobileModel(self):
    batch_size = 5
    height, width = 224, 224
    num_classes = None
    inputs = tf.random.uniform((batch_size, height, width, 3))
    tf.compat.v1.train.create_global_step()
    with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
      net, end_points = nasnet.build_nasnet_mobile(inputs, num_classes)
    self._assertPreLogits(net, end_points, batch_size, 1056)

  def testBuildPreLogitsLargeModel(self):
    batch_size = 5
    height, width = 331, 331
    num_classes = None
    inputs = tf.random.uniform((batch_size, height, width, 3))
    tf.compat.v1.train.create_global_step()
    with slim.arg_scope(nasnet.nasnet_large_arg_scope()):
      net, end_points = nasnet.build_nasnet_large(inputs, num_classes)
    self._assertPreLogits(net, end_points, batch_size, 4032)

  def testAllEndPointsShapesCifarModel(self):
    batch_size = 5
    height, width = 32, 32
    num_classes = 10
    inputs = tf.random.uniform((batch_size, height, width, 3))
    tf.compat.v1.train.create_global_step()
    with slim.arg_scope(nasnet.nasnet_cifar_arg_scope()):
      _, end_points = nasnet.build_nasnet_cifar(inputs, num_classes)
    endpoints_shapes = {'Stem': [batch_size, 32, 32, 96],
                        'Cell_0': [batch_size, 32, 32, 192],
                        'Cell_1': [batch_size, 32, 32, 192],
                        'Cell_2': [batch_size, 32, 32, 192],
                        'Cell_3': [batch_size, 32, 32, 192],
                        'Cell_4': [batch_size, 32, 32, 192],
                        'Cell_5': [batch_size, 32, 32, 192],
                        'Cell_6': [batch_size, 16, 16, 384],
                        'Cell_7': [batch_size, 16, 16, 384],
                        'Cell_8': [batch_size, 16, 16, 384],
                        'Cell_9': [batch_size, 16, 16, 384],
                        'Cell_10': [batch_size, 16, 16, 384],
                        'Cell_11': [batch_size, 16, 16, 384],
                        'Cell_12': [batch_size, 8, 8, 768],
                        'Cell_13': [batch_size, 8, 8, 768],
                        'Cell_14': [batch_size, 8, 8, 768],
                        'Cell_15': [batch_size, 8, 8, 768],
                        'Cell_16': [batch_size, 8, 8, 768],
                        'Cell_17': [batch_size, 8, 8, 768],
                        'Reduction_Cell_0': [batch_size, 16, 16, 256],
                        'Reduction_Cell_1': [batch_size, 8, 8, 512],
                        'global_pool': [batch_size, 768],
                        # Logits and predictions
                        'AuxLogits': [batch_size, num_classes],
                        'Logits': [batch_size, num_classes],
                        'Predictions': [batch_size, num_classes]}
    self._assertEndPointShapes(end_points, endpoints_shapes)

  def testNoAuxHeadCifarModel(self):
    self._assertAuxHeadToggle(nasnet.build_nasnet_cifar,
                              nasnet.nasnet_cifar_arg_scope,
                              nasnet.cifar_config,
                              height=32, width=32, num_classes=10)

  def testAllEndPointsShapesMobileModel(self):
    batch_size = 5
    height, width = 224, 224
    num_classes = 1000
    inputs = tf.random.uniform((batch_size, height, width, 3))
    tf.compat.v1.train.create_global_step()
    with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
      _, end_points = nasnet.build_nasnet_mobile(inputs, num_classes)
    endpoints_shapes = {'Stem': [batch_size, 28, 28, 88],
                        'Cell_0': [batch_size, 28, 28, 264],
                        'Cell_1': [batch_size, 28, 28, 264],
                        'Cell_2': [batch_size, 28, 28, 264],
                        'Cell_3': [batch_size, 28, 28, 264],
                        'Cell_4': [batch_size, 14, 14, 528],
                        'Cell_5': [batch_size, 14, 14, 528],
                        'Cell_6': [batch_size, 14, 14, 528],
                        'Cell_7': [batch_size, 14, 14, 528],
                        'Cell_8': [batch_size, 7, 7, 1056],
                        'Cell_9': [batch_size, 7, 7, 1056],
                        'Cell_10': [batch_size, 7, 7, 1056],
                        'Cell_11': [batch_size, 7, 7, 1056],
                        'Reduction_Cell_0': [batch_size, 14, 14, 352],
                        'Reduction_Cell_1': [batch_size, 7, 7, 704],
                        'global_pool': [batch_size, 1056],
                        # Logits and predictions
                        'AuxLogits': [batch_size, num_classes],
                        'Logits': [batch_size, num_classes],
                        'Predictions': [batch_size, num_classes]}
    self._assertEndPointShapes(end_points, endpoints_shapes)

  def testNoAuxHeadMobileModel(self):
    self._assertAuxHeadToggle(nasnet.build_nasnet_mobile,
                              nasnet.nasnet_mobile_arg_scope,
                              nasnet.mobile_imagenet_config,
                              height=224, width=224, num_classes=1000)

  def testAllEndPointsShapesLargeModel(self):
    batch_size = 5
    height, width = 331, 331
    num_classes = 1000
    inputs = tf.random.uniform((batch_size, height, width, 3))
    tf.compat.v1.train.create_global_step()
    with slim.arg_scope(nasnet.nasnet_large_arg_scope()):
      _, end_points = nasnet.build_nasnet_large(inputs, num_classes)
    endpoints_shapes = {'Stem': [batch_size, 42, 42, 336],
                        'Cell_0': [batch_size, 42, 42, 1008],
                        'Cell_1': [batch_size, 42, 42, 1008],
                        'Cell_2': [batch_size, 42, 42, 1008],
                        'Cell_3': [batch_size, 42, 42, 1008],
                        'Cell_4': [batch_size, 42, 42, 1008],
                        'Cell_5': [batch_size, 42, 42, 1008],
                        'Cell_6': [batch_size, 21, 21, 2016],
                        'Cell_7': [batch_size, 21, 21, 2016],
                        'Cell_8': [batch_size, 21, 21, 2016],
                        'Cell_9': [batch_size, 21, 21, 2016],
                        'Cell_10': [batch_size, 21, 21, 2016],
                        'Cell_11': [batch_size, 21, 21, 2016],
                        'Cell_12': [batch_size, 11, 11, 4032],
                        'Cell_13': [batch_size, 11, 11, 4032],
                        'Cell_14': [batch_size, 11, 11, 4032],
                        'Cell_15': [batch_size, 11, 11, 4032],
                        'Cell_16': [batch_size, 11, 11, 4032],
                        'Cell_17': [batch_size, 11, 11, 4032],
                        'Reduction_Cell_0': [batch_size, 21, 21, 1344],
                        'Reduction_Cell_1': [batch_size, 11, 11, 2688],
                        'global_pool': [batch_size, 4032],
                        # Logits and predictions
                        'AuxLogits': [batch_size, num_classes],
                        'Logits': [batch_size, num_classes],
                        'Predictions': [batch_size, num_classes]}
    self._assertEndPointShapes(end_points, endpoints_shapes)

  def testNoAuxHeadLargeModel(self):
    self._assertAuxHeadToggle(nasnet.build_nasnet_large,
                              nasnet.nasnet_large_arg_scope,
                              nasnet.large_imagenet_config,
                              height=331, width=331, num_classes=1000)

  def testVariablesSetDeviceMobileModel(self):
    batch_size = 5
    height, width = 224, 224
    num_classes = 1000
    inputs = tf.random.uniform((batch_size, height, width, 3))
    tf.compat.v1.train.create_global_step()
    # Build the model twice, each copy under its own variable scope and
    # pinned to a different device, then check where its variables landed.
    with tf.compat.v1.variable_scope('on_cpu'), tf.device('/cpu:0'):
      with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
        nasnet.build_nasnet_mobile(inputs, num_classes)
    with tf.compat.v1.variable_scope('on_gpu'), tf.device('/gpu:0'):
      with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
        nasnet.build_nasnet_mobile(inputs, num_classes)
    for scope, device in (('on_cpu', '/cpu:0'), ('on_gpu', '/gpu:0')):
      scope_variables = tf.compat.v1.get_collection(
          tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, scope=scope)
      # Without this guard the device loop below would pass vacuously if the
      # scope filter matched nothing.
      self.assertTrue(scope_variables)
      for v in scope_variables:
        self.assertDeviceEqual(v.device, device)

  def testUnknownBatchSizeMobileModel(self):
    batch_size = 1
    height, width = 224, 224
    num_classes = 1000
    with self.cached_session() as sess:
      # The batch dimension is left unknown (None) at graph-build time.
      inputs = tf.compat.v1.placeholder(tf.float32, (None, height, width, 3))
      with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
        logits, _ = nasnet.build_nasnet_mobile(inputs, num_classes)
      self.assertListEqual(logits.get_shape().as_list(),
                           [None, num_classes])
      images = tf.random.uniform((batch_size, height, width, 3))
      sess.run(tf.compat.v1.global_variables_initializer())
      output = sess.run(logits, {inputs: images.eval()})
      self.assertEqual(output.shape, (batch_size, num_classes))

  def testEvaluationMobileModel(self):
    batch_size = 2
    height, width = 224, 224
    num_classes = 1000
    with self.cached_session() as sess:
      eval_inputs = tf.random.uniform((batch_size, height, width, 3))
      with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
        logits, _ = nasnet.build_nasnet_mobile(eval_inputs,
                                               num_classes,
                                               is_training=False)
      predictions = tf.argmax(input=logits, axis=1)
      sess.run(tf.compat.v1.global_variables_initializer())
      output = sess.run(predictions)
      self.assertEqual(output.shape, (batch_size,))

  def testOverrideHParamsCifarModel(self):
    batch_size = 5
    height, width = 32, 32
    num_classes = 10
    inputs = tf.random.uniform((batch_size, height, width, 3))
    tf.compat.v1.train.create_global_step()
    config = nasnet.cifar_config()
    config.set_hparam('data_format', 'NCHW')
    with slim.arg_scope(nasnet.nasnet_cifar_arg_scope()):
      _, end_points = nasnet.build_nasnet_cifar(
          inputs, num_classes, config=config)
    # With NCHW the channel dimension comes right after the batch dimension.
    self.assertListEqual(
        end_points['Stem'].shape.as_list(), [batch_size, 96, 32, 32])

  def testOverrideHParamsMobileModel(self):
    batch_size = 5
    height, width = 224, 224
    num_classes = 1000
    inputs = tf.random.uniform((batch_size, height, width, 3))
    tf.compat.v1.train.create_global_step()
    config = nasnet.mobile_imagenet_config()
    config.set_hparam('data_format', 'NCHW')
    with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
      _, end_points = nasnet.build_nasnet_mobile(
          inputs, num_classes, config=config)
    # With NCHW the channel dimension comes right after the batch dimension.
    self.assertListEqual(
        end_points['Stem'].shape.as_list(), [batch_size, 88, 28, 28])

  def testOverrideHParamsLargeModel(self):
    batch_size = 5
    height, width = 331, 331
    num_classes = 1000
    inputs = tf.random.uniform((batch_size, height, width, 3))
    tf.compat.v1.train.create_global_step()
    config = nasnet.large_imagenet_config()
    config.set_hparam('data_format', 'NCHW')
    with slim.arg_scope(nasnet.nasnet_large_arg_scope()):
      _, end_points = nasnet.build_nasnet_large(
          inputs, num_classes, config=config)
    # With NCHW the channel dimension comes right after the batch dimension.
    self.assertListEqual(
        end_points['Stem'].shape.as_list(), [batch_size, 336, 42, 42])

  def testCurrentStepCifarModel(self):
    batch_size = 5
    height, width = 32, 32
    num_classes = 10
    inputs = tf.random.uniform((batch_size, height, width, 3))
    global_step = tf.compat.v1.train.create_global_step()
    with slim.arg_scope(nasnet.nasnet_cifar_arg_scope()):
      logits, end_points = nasnet.build_nasnet_cifar(inputs,
                                                     num_classes,
                                                     current_step=global_step)
    self._assertLogitsShapes(logits, end_points, batch_size, num_classes)

  # NOTE: the 'Acitvation' misspelling is kept so the test id stays stable.
  def testUseBoundedAcitvationCifarModel(self):
    batch_size = 1
    height, width = 32, 32
    num_classes = 10
    for use_bounded_activation in (True, False):
      tf.compat.v1.reset_default_graph()
      inputs = tf.random.uniform((batch_size, height, width, 3))
      config = nasnet.cifar_config()
      config.set_hparam('use_bounded_activation', use_bounded_activation)
      with slim.arg_scope(nasnet.nasnet_cifar_arg_scope()):
        nasnet.build_nasnet_cifar(inputs, num_classes, config=config)
      # Every Relu* op in the graph must be Relu6 exactly when bounded
      # activations were requested.
      for node in tf.compat.v1.get_default_graph().as_graph_def().node:
        if node.op.startswith('Relu'):
          self.assertEqual(node.op == 'Relu6', use_bounded_activation)

if __name__ == '__main__':
  # Run the tf.test.TestCase tests defined in this module.
  tf.test.main()