# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains the definition for the NASNet classification networks.

Paper: https://arxiv.org/abs/1707.07012
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

23
import copy
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import tensorflow as tf

from nets.nasnet import nasnet_utils

arg_scope = tf.contrib.framework.arg_scope
slim = tf.contrib.slim


# Notes for training NASNet Cifar Model
# -------------------------------------
# batch_size: 32
# learning rate: 0.025
# cosine (single period) learning rate decay
# auxiliary head loss weighting: 0.4
# clip global norm of all gradients by 5
def cifar_config():
  """Returns the default hyperparameters for the NASNet Cifar model.

  Returns:
    A `tf.contrib.training.HParams` object holding the Cifar defaults.
  """
  return tf.contrib.training.HParams(
      stem_multiplier=3.0,
      drop_path_keep_prob=0.6,
      num_cells=18,
      use_aux_head=1,
      num_conv_filters=32,
      dense_dropout_keep_prob=1.0,
      filter_scaling_rate=2.0,
      num_reduction_layers=2,
      data_format='NHWC',
      skip_reduction_layer_input=0,
      # 600 epochs with a batch size of 32
      # This is used for the drop path probabilities since it needs to increase
      # the drop out probability over the course of training.
      total_training_steps=937500,
  )


# Notes for training large NASNet model on ImageNet
# -------------------------------------
# batch size (per replica): 16
# learning rate: 0.015 * 100
# learning rate decay factor: 0.97
# num epochs per decay: 2.4
# sync sgd with 100 replicas
# auxiliary head loss weighting: 0.4
# label smoothing: 0.1
# clip global norm of all gradients by 10
def large_imagenet_config():
  """Returns the default hyperparameters for the NASNet Large ImageNet model.

  Returns:
    A `tf.contrib.training.HParams` object holding the large ImageNet defaults.
  """
  return tf.contrib.training.HParams(
      stem_multiplier=3.0,
      dense_dropout_keep_prob=0.5,
      num_cells=18,
      filter_scaling_rate=2.0,
      num_conv_filters=168,
      drop_path_keep_prob=0.7,
      use_aux_head=1,
      num_reduction_layers=2,
      data_format='NHWC',
      skip_reduction_layer_input=1,
      total_training_steps=250000,
  )


# Notes for training the mobile NASNet ImageNet model
# -------------------------------------
# batch size (per replica): 32
# learning rate: 0.04 * 50
# learning rate decay factor: 0.97
# num epochs per decay: 2.4
# sync sgd with 50 replicas
# auxiliary head loss weighting: 0.4
# label smoothing: 0.1
# clip global norm of all gradients by 10
def mobile_imagenet_config():
  """Returns the default hyperparameters for the NASNet Mobile ImageNet model.

  Returns:
    A `tf.contrib.training.HParams` object holding the mobile ImageNet
    defaults.
  """
  return tf.contrib.training.HParams(
      stem_multiplier=1.0,
      dense_dropout_keep_prob=0.5,
      num_cells=12,
      filter_scaling_rate=2.0,
      drop_path_keep_prob=1.0,
      num_conv_filters=44,
      use_aux_head=1,
      num_reduction_layers=2,
      data_format='NHWC',
      skip_reduction_layer_input=0,
      total_training_steps=250000,
  )


110
111
112
113
114
115
def _update_hparams(hparams, is_training):
  """Adjusts `hparams` in place for the requested training mode.

  Args:
    hparams: Object exposing `set_hparam(name, value)`; mutated in place.
    is_training: When falsy, `drop_path_keep_prob` is forced to 1.0. When
      truthy, `hparams` is left untouched.
  """
  if is_training:
    return
  # Outside of training every path is kept.
  hparams.set_hparam('drop_path_keep_prob', 1.0)


116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
def nasnet_cifar_arg_scope(weight_decay=5e-4,
                           batch_norm_decay=0.9,
                           batch_norm_epsilon=1e-5):
  """Builds the default arg scope for the NASNet-A Cifar model.

  Args:
    weight_decay: Coefficient of the L2 regularizer applied to the weights of
      fully connected, conv2d and separable conv2d layers.
    batch_norm_decay: Decay used for the batch norm moving averages.
    batch_norm_epsilon: Small float added to the batch norm variance to avoid
      dividing by zero.

  Returns:
    An `arg_scope` to use for the NASNet Cifar Model.
  """
  regularizer = tf.contrib.layers.l2_regularizer(weight_decay)
  initializer = tf.contrib.layers.variance_scaling_initializer(mode='FAN_OUT')
  # Keyword arguments applied to every slim.batch_norm call in the scope.
  bn_kwargs = dict(
      decay=batch_norm_decay,
      epsilon=batch_norm_epsilon,
      scale=True,
      fused=True)
  conv_ops = [slim.conv2d, slim.separable_conv2d]
  weighted_ops = [slim.fully_connected] + conv_ops

  with arg_scope(weighted_ops,
                 weights_regularizer=regularizer,
                 weights_initializer=initializer):
    with arg_scope([slim.fully_connected], activation_fn=None, scope='FC'):
      # Convolutions carry neither an activation nor a bias.
      with arg_scope(conv_ops, activation_fn=None, biases_initializer=None):
        with arg_scope([slim.batch_norm], **bn_kwargs) as final_scope:
          return final_scope


def nasnet_mobile_arg_scope(weight_decay=4e-5,
                            batch_norm_decay=0.9997,
                            batch_norm_epsilon=1e-3):
  """Builds the default arg scope for the NASNet-A Mobile ImageNet model.

  Args:
    weight_decay: Coefficient of the L2 regularizer applied to the weights of
      fully connected, conv2d and separable conv2d layers.
    batch_norm_decay: Decay used for the batch norm moving averages.
    batch_norm_epsilon: Small float added to the batch norm variance to avoid
      dividing by zero.

  Returns:
    An `arg_scope` to use for the NASNet Mobile Model.
  """
  regularizer = tf.contrib.layers.l2_regularizer(weight_decay)
  initializer = tf.contrib.layers.variance_scaling_initializer(mode='FAN_OUT')
  # Keyword arguments applied to every slim.batch_norm call in the scope.
  bn_kwargs = dict(
      decay=batch_norm_decay,
      epsilon=batch_norm_epsilon,
      scale=True,
      fused=True)
  conv_ops = [slim.conv2d, slim.separable_conv2d]
  weighted_ops = [slim.fully_connected] + conv_ops

  with arg_scope(weighted_ops,
                 weights_regularizer=regularizer,
                 weights_initializer=initializer):
    with arg_scope([slim.fully_connected], activation_fn=None, scope='FC'):
      # Convolutions carry neither an activation nor a bias.
      with arg_scope(conv_ops, activation_fn=None, biases_initializer=None):
        with arg_scope([slim.batch_norm], **bn_kwargs) as final_scope:
          return final_scope


def nasnet_large_arg_scope(weight_decay=5e-5,
                           batch_norm_decay=0.9997,
                           batch_norm_epsilon=1e-3):
  """Builds the default arg scope for the NASNet-A Large ImageNet model.

  Args:
    weight_decay: Coefficient of the L2 regularizer applied to the weights of
      fully connected, conv2d and separable conv2d layers.
    batch_norm_decay: Decay used for the batch norm moving averages.
    batch_norm_epsilon: Small float added to the batch norm variance to avoid
      dividing by zero.

  Returns:
    An `arg_scope` to use for the NASNet Large Model.
  """
  regularizer = tf.contrib.layers.l2_regularizer(weight_decay)
  initializer = tf.contrib.layers.variance_scaling_initializer(mode='FAN_OUT')
  # Keyword arguments applied to every slim.batch_norm call in the scope.
  bn_kwargs = dict(
      decay=batch_norm_decay,
      epsilon=batch_norm_epsilon,
      scale=True,
      fused=True)
  conv_ops = [slim.conv2d, slim.separable_conv2d]
  weighted_ops = [slim.fully_connected] + conv_ops

  with arg_scope(weighted_ops,
                 weights_regularizer=regularizer,
                 weights_initializer=initializer):
    with arg_scope([slim.fully_connected], activation_fn=None, scope='FC'):
      # Convolutions carry neither an activation nor a bias.
      with arg_scope(conv_ops, activation_fn=None, biases_initializer=None):
        with arg_scope([slim.batch_norm], **bn_kwargs) as final_scope:
          return final_scope


def _build_aux_head(net, end_points, num_classes, hparams, scope):
  """Builds the auxiliary classifier head shared by all models and datasets.

  Args:
    net: Feature tensor the auxiliary head is attached to.
    end_points: Dict that receives the head output under 'AuxLogits'.
    num_classes: Number of outputs of the final fully connected layer.
    hparams: Hyperparameters; only `data_format` is read here.
    scope: Variable scope enclosing the whole head.
  """
  is_channels_last = hparams.data_format == 'NHWC'
  with tf.variable_scope(scope):
    head = tf.identity(net)
    with tf.variable_scope('aux_logits'):
      head = slim.avg_pool2d(head, [5, 5], stride=3, padding='VALID')
      head = slim.conv2d(head, 128, [1, 1], scope='proj')
      head = tf.nn.relu(slim.batch_norm(head, scope='aux_bn0'))
      # The next convolution uses the full spatial extent of the feature map
      # as its kernel size.
      full_shape = head.shape
      spatial_dims = full_shape[1:3] if is_channels_last else full_shape[2:4]
      head = slim.conv2d(head, 768, spatial_dims, padding='VALID')
      head = tf.nn.relu(slim.batch_norm(head, scope='aux_bn1'))
      head = tf.contrib.layers.flatten(head)
      end_points['AuxLogits'] = slim.fully_connected(head, num_classes)


248
def _imagenet_stem(inputs, hparams, stem_cell, current_step=None):
  """Stem used for models trained on ImageNet.

  Args:
    inputs: Image tensor fed to the first convolution.
    hparams: Hyperparameters; `stem_multiplier` and `filter_scaling_rate` are
      read here.
    stem_cell: Callable applied once per stem cell, always with stride 2.
    current_step: Optional step value passed through to `stem_cell`.

  Returns:
    A tuple `(net, cell_outputs)`. `net` is the output of the last stem cell.
    `cell_outputs` is the list `[None, <conv0_bn output>, <stem cell 0 output>,
    <stem cell 1 output>]`.
  """
  num_stem_cells = 2

  # 149 x 149 x 32 (original note; presumably for a 299 x 299 input with
  # stem_multiplier == 1 -- TODO confirm)
  num_stem_filters = int(32 * hparams.stem_multiplier)
  net = slim.conv2d(
      inputs, num_stem_filters, [3, 3], stride=2, scope='conv0',
      padding='VALID')
  net = slim.batch_norm(net, scope='conv0_bn')

  # Run the reduction cells
  cell_outputs = [None, net]
  # Start below 1.0 so that, after multiplying by filter_scaling_rate once per
  # stem cell, the scaling is back at 1.0 when the stem is done.
  filter_scaling = 1.0 / (hparams.filter_scaling_rate**num_stem_cells)
  for cell_num in range(num_stem_cells):
    net = stem_cell(
        net,
        scope='cell_stem_{}'.format(cell_num),
        filter_scaling=filter_scaling,
        stride=2,
        prev_layer=cell_outputs[-2],
        cell_num=cell_num,
        current_step=current_step)
    cell_outputs.append(net)
    filter_scaling *= hparams.filter_scaling_rate
  return net, cell_outputs


def _cifar_stem(inputs, hparams):
  """Builds the stem for models trained on Cifar.

  Args:
    inputs: Image tensor fed to the stem convolution.
    hparams: Hyperparameters; `num_conv_filters` and `stem_multiplier` set the
      number of stem filters.

  Returns:
    A tuple `(net, cell_outputs)` where `net` is the batch-normalized stem
    output and `cell_outputs` is the list `[None, net]`.
  """
  stem_width = int(hparams.num_conv_filters * hparams.stem_multiplier)
  conv_out = slim.conv2d(inputs, stem_width, 3, scope='l1_stem_3x3')
  stem_out = slim.batch_norm(conv_out, scope='l1_stem_bn')
  return stem_out, [None, stem_out]


288
289
def build_nasnet_cifar(images, num_classes,
                       is_training=True,
                       config=None,
                       current_step=None):
  """Build NASNet model for the Cifar Dataset.

  Args:
    images: Input image tensor. It is transposed with `[0, 3, 1, 2]` when the
      data format is 'NCHW', so it is presumably supplied as NHWC.
    num_classes: Number of classes, forwarded to `_build_nasnet_base`.
    is_training: Whether the graph is built for training. Sets `is_training`
      for dropout, drop path and batch norm; when False,
      `drop_path_keep_prob` is forced to 1.0.
    config: Optional hyperparameters. A deep copy is used, so the caller's
      object is not modified. Defaults to `cifar_config()`.
    current_step: Optional step value forwarded to `_build_nasnet_base`.

  Returns:
    The `(net, end_points)` tuple returned by `_build_nasnet_base`.
  """
  hparams = cifar_config() if config is None else copy.deepcopy(config)
  _update_hparams(hparams, is_training)

  if tf.test.is_gpu_available() and hparams.data_format == 'NHWC':
    tf.logging.info('A GPU is available on the machine, consider using NCHW '
                    'data format for increased speed on GPU.')

  if hparams.data_format == 'NCHW':
    images = tf.transpose(images, [0, 3, 1, 2])

  # Calculate the total number of cells in the network
  # Add 2 for the reduction cells
  total_num_cells = hparams.num_cells + 2

  normal_cell = nasnet_utils.NasNetANormalCell(
      hparams.num_conv_filters, hparams.drop_path_keep_prob,
      total_num_cells, hparams.total_training_steps)
  reduction_cell = nasnet_utils.NasNetAReductionCell(
      hparams.num_conv_filters, hparams.drop_path_keep_prob,
      total_num_cells, hparams.total_training_steps)
  with arg_scope([slim.dropout, nasnet_utils.drop_path, slim.batch_norm],
                 is_training=is_training):
    with arg_scope([slim.avg_pool2d,
                    slim.max_pool2d,
                    slim.conv2d,
                    slim.batch_norm,
                    slim.separable_conv2d,
                    nasnet_utils.factorized_reduction,
                    nasnet_utils.global_avg_pool,
                    nasnet_utils.get_channel_index,
                    nasnet_utils.get_channel_dim],
                   data_format=hparams.data_format):
      return _build_nasnet_base(images,
                                normal_cell=normal_cell,
                                reduction_cell=reduction_cell,
                                num_classes=num_classes,
                                hparams=hparams,
                                is_training=is_training,
                                stem_type='cifar',
                                current_step=current_step)
build_nasnet_cifar.default_image_size = 32


def build_nasnet_mobile(images, num_classes,
                        is_training=True,
                        final_endpoint=None,
                        config=None,
                        current_step=None):
  """Build NASNet Mobile model for the ImageNet Dataset.

  Args:
    images: Input image tensor. It is transposed with `[0, 3, 1, 2]` when the
      data format is 'NCHW', so it is presumably supplied as NHWC.
    num_classes: Number of classes, forwarded to `_build_nasnet_base`.
    is_training: Whether the graph is built for training. Sets `is_training`
      for dropout, drop path and batch norm; when False,
      `drop_path_keep_prob` is forced to 1.0.
    final_endpoint: Optional endpoint name forwarded to `_build_nasnet_base`,
      which stops building once that endpoint has been recorded.
    config: Optional hyperparameters. A deep copy is used, so the caller's
      object is not modified. Defaults to `mobile_imagenet_config()`.
    current_step: Optional step value forwarded to `_build_nasnet_base`.

  Returns:
    The `(net, end_points)` tuple returned by `_build_nasnet_base`.
  """
  hparams = (mobile_imagenet_config() if config is None
             else copy.deepcopy(config))
  _update_hparams(hparams, is_training)

  if tf.test.is_gpu_available() and hparams.data_format == 'NHWC':
    tf.logging.info('A GPU is available on the machine, consider using NCHW '
                    'data format for increased speed on GPU.')

  if hparams.data_format == 'NCHW':
    images = tf.transpose(images, [0, 3, 1, 2])

  # Calculate the total number of cells in the network
  # Add 2 for the reduction cells
  total_num_cells = hparams.num_cells + 2
  # If ImageNet, then add an additional two for the stem cells
  total_num_cells += 2

  normal_cell = nasnet_utils.NasNetANormalCell(
      hparams.num_conv_filters, hparams.drop_path_keep_prob,
      total_num_cells, hparams.total_training_steps)
  reduction_cell = nasnet_utils.NasNetAReductionCell(
      hparams.num_conv_filters, hparams.drop_path_keep_prob,
      total_num_cells, hparams.total_training_steps)
  with arg_scope([slim.dropout, nasnet_utils.drop_path, slim.batch_norm],
                 is_training=is_training):
    with arg_scope([slim.avg_pool2d,
                    slim.max_pool2d,
                    slim.conv2d,
                    slim.batch_norm,
                    slim.separable_conv2d,
                    nasnet_utils.factorized_reduction,
                    nasnet_utils.global_avg_pool,
                    nasnet_utils.get_channel_index,
                    nasnet_utils.get_channel_dim],
                   data_format=hparams.data_format):
      return _build_nasnet_base(images,
                                normal_cell=normal_cell,
                                reduction_cell=reduction_cell,
                                num_classes=num_classes,
                                hparams=hparams,
                                is_training=is_training,
                                stem_type='imagenet',
                                final_endpoint=final_endpoint,
                                current_step=current_step)
build_nasnet_mobile.default_image_size = 224


def build_nasnet_large(images, num_classes,
                       is_training=True,
                       final_endpoint=None,
                       config=None,
                       current_step=None):
  """Build NASNet Large model for the ImageNet Dataset.

  Args:
    images: Input image tensor. It is transposed with `[0, 3, 1, 2]` when the
      data format is 'NCHW', so it is presumably supplied as NHWC.
    num_classes: Number of classes, forwarded to `_build_nasnet_base`.
    is_training: Whether the graph is built for training. Sets `is_training`
      for dropout, drop path and batch norm; when False,
      `drop_path_keep_prob` is forced to 1.0.
    final_endpoint: Optional endpoint name forwarded to `_build_nasnet_base`,
      which stops building once that endpoint has been recorded.
    config: Optional hyperparameters. A deep copy is used, so the caller's
      object is not modified. Defaults to `large_imagenet_config()`.
    current_step: Optional step value forwarded to `_build_nasnet_base`.

  Returns:
    The `(net, end_points)` tuple returned by `_build_nasnet_base`.
  """
  hparams = (large_imagenet_config() if config is None
             else copy.deepcopy(config))
  _update_hparams(hparams, is_training)

  if tf.test.is_gpu_available() and hparams.data_format == 'NHWC':
    tf.logging.info('A GPU is available on the machine, consider using NCHW '
                    'data format for increased speed on GPU.')

  if hparams.data_format == 'NCHW':
    images = tf.transpose(images, [0, 3, 1, 2])

  # Calculate the total number of cells in the network
  # Add 2 for the reduction cells
  total_num_cells = hparams.num_cells + 2
  # If ImageNet, then add an additional two for the stem cells
  total_num_cells += 2

  normal_cell = nasnet_utils.NasNetANormalCell(
      hparams.num_conv_filters, hparams.drop_path_keep_prob,
      total_num_cells, hparams.total_training_steps)
  reduction_cell = nasnet_utils.NasNetAReductionCell(
      hparams.num_conv_filters, hparams.drop_path_keep_prob,
      total_num_cells, hparams.total_training_steps)
  with arg_scope([slim.dropout, nasnet_utils.drop_path, slim.batch_norm],
                 is_training=is_training):
    with arg_scope([slim.avg_pool2d,
                    slim.max_pool2d,
                    slim.conv2d,
                    slim.batch_norm,
                    slim.separable_conv2d,
                    nasnet_utils.factorized_reduction,
                    nasnet_utils.global_avg_pool,
                    nasnet_utils.get_channel_index,
                    nasnet_utils.get_channel_dim],
                   data_format=hparams.data_format):
      return _build_nasnet_base(images,
                                normal_cell=normal_cell,
                                reduction_cell=reduction_cell,
                                num_classes=num_classes,
                                hparams=hparams,
                                is_training=is_training,
                                stem_type='imagenet',
                                final_endpoint=final_endpoint,
                                current_step=current_step)
build_nasnet_large.default_image_size = 331


def _build_nasnet_base(images,
                       normal_cell,
                       reduction_cell,
                       num_classes,
                       hparams,
                       is_training,
                       stem_type,
                       final_endpoint=None,
                       current_step=None):
  """Constructs a NASNet image model.

  Args:
    images: Input tensor fed to the stem.
    normal_cell: Callable used to build every stride-1 cell.
    reduction_cell: Callable used to build the reduction cells. It is also
      used as the stem cell when `stem_type` is 'imagenet'.
    num_classes: Number of classes. When falsy, no auxiliary head is built and
      construction stops after global pooling.
    hparams: Hyperparameters. Reads `num_cells`, `num_reduction_layers`,
      `skip_reduction_layer_input`, `filter_scaling_rate`, `use_aux_head` and
      `dense_dropout_keep_prob`.
    is_training: The auxiliary head is only built when this is truthy.
    stem_type: Either 'imagenet' or 'cifar'.
    final_endpoint: Optional endpoint name. Construction stops right after the
      endpoint with this name has been recorded.
    current_step: Optional step value forwarded to the stem cells and to every
      normal and reduction cell.

  Returns:
    A tuple `(net, end_points)`. `end_points` maps endpoint names to tensors.
    `net` is the logits when the network is built to completion, the pooled
    features when `num_classes` is falsy, and otherwise the tensor that was
    current when `final_endpoint` was reached (see the note in the final layer
    about 'Logits' and 'Predictions').

  Raises:
    ValueError: If `stem_type` is neither 'imagenet' nor 'cifar'.
  """

  end_points = {}

  def add_and_check_endpoint(endpoint_name, net):
    """Records `net`; the result is truthy if building should stop here."""
    end_points[endpoint_name] = net
    return final_endpoint and (endpoint_name == final_endpoint)

  # Find where to place the reduction cells or stride normal cells
  reduction_indices = nasnet_utils.calc_reduction_layers(
      hparams.num_cells, hparams.num_reduction_layers)
  stem_cell = reduction_cell

  if stem_type == 'imagenet':
    # Forward current_step so the stem cells receive the same step value as
    # every other cell in the network.
    net, cell_outputs = _imagenet_stem(
        images, hparams, stem_cell, current_step=current_step)
  elif stem_type == 'cifar':
    net, cell_outputs = _cifar_stem(images, hparams)
  else:
    raise ValueError('Unknown stem_type: {}'.format(stem_type))
  if add_and_check_endpoint('Stem', net):
    return net, end_points

  # Setup for building in the auxiliary head.
  aux_head_cell_idxes = []
  if len(reduction_indices) >= 2:
    aux_head_cell_idxes.append(reduction_indices[1] - 1)

  # Run the cells
  filter_scaling = 1.0
  # true_cell_num accounts for the stem cells
  true_cell_num = 2 if stem_type == 'imagenet' else 0
  for cell_num in range(hparams.num_cells):
    stride = 1
    if hparams.skip_reduction_layer_input:
      # Captured before a reduction cell output can be appended below, so the
      # normal cell's prev_layer skips over that reduction cell's input.
      prev_layer = cell_outputs[-2]
    if cell_num in reduction_indices:
      filter_scaling *= hparams.filter_scaling_rate
      net = reduction_cell(
          net,
          scope='reduction_cell_{}'.format(reduction_indices.index(cell_num)),
          filter_scaling=filter_scaling,
          stride=2,
          prev_layer=cell_outputs[-2],
          cell_num=true_cell_num,
          current_step=current_step)
      if add_and_check_endpoint(
          'Reduction_Cell_{}'.format(reduction_indices.index(cell_num)), net):
        return net, end_points
      true_cell_num += 1
      cell_outputs.append(net)
    if not hparams.skip_reduction_layer_input:
      prev_layer = cell_outputs[-2]
    net = normal_cell(
        net,
        scope='cell_{}'.format(cell_num),
        filter_scaling=filter_scaling,
        stride=stride,
        prev_layer=prev_layer,
        cell_num=true_cell_num,
        current_step=current_step)

    if add_and_check_endpoint('Cell_{}'.format(cell_num), net):
      return net, end_points
    true_cell_num += 1
    if (hparams.use_aux_head and cell_num in aux_head_cell_idxes and
        num_classes and is_training):
      aux_net = tf.nn.relu(net)
      _build_aux_head(aux_net, end_points, num_classes, hparams,
                      scope='aux_{}'.format(cell_num))
    cell_outputs.append(net)

  # Final softmax layer
  with tf.variable_scope('final_layer'):
    net = tf.nn.relu(net)
    net = nasnet_utils.global_avg_pool(net)
    if add_and_check_endpoint('global_pool', net) or not num_classes:
      return net, end_points
    net = slim.dropout(net, hparams.dense_dropout_keep_prob, scope='dropout')
    logits = slim.fully_connected(net, num_classes)

    # NOTE(review): the two early returns below hand back `net` (the dropout
    # output), not the endpoint tensor that was just recorded, unlike every
    # earlier early return. Kept as-is because changing it would alter the
    # shape callers receive; the endpoint tensors remain available through
    # `end_points`. Confirm whether this is intended.
    if add_and_check_endpoint('Logits', logits):
      return net, end_points

    predictions = tf.nn.softmax(logits, name='predictions')
    if add_and_check_endpoint('Predictions', predictions):
      return net, end_points
  return logits, end_points