Commit cf8f014b authored by Jon Shlens's avatar Jon Shlens
Browse files

Fix nasnet image classification and object detection

Fix nasnet image classification and object detection by moving the
option to turn ON or OFF batch norm training into its own arg_scope
used only by detection
parent b3f04bca
...@@ -30,6 +30,23 @@ arg_scope = tf.contrib.framework.arg_scope ...@@ -30,6 +30,23 @@ arg_scope = tf.contrib.framework.arg_scope
slim = tf.contrib.slim slim = tf.contrib.slim
def nasnet_large_arg_scope_for_detection(is_batch_norm_training=False):
  """Defines the default arg scope for the NASNet-A Large for object detection.

  This provides a small edit to switch batch norm training on and off.

  Args:
    is_batch_norm_training: Boolean indicating whether to train with batch norm.

  Returns:
    An `arg_scope` to use for the NASNet Large Model.
  """
  # Start from the standard ImageNet classification scope, then override
  # only the batch-norm training flag for the detection use case.
  base_scope = nasnet.nasnet_large_arg_scope()
  with arg_scope(base_scope):
    with arg_scope([slim.batch_norm],
                   is_training=is_batch_norm_training) as detection_scope:
      return detection_scope
# Note: This is largely a copy of _build_nasnet_base inside nasnet.py but # Note: This is largely a copy of _build_nasnet_base inside nasnet.py but
# with special edits to remove instantiation of the stem and the special # with special edits to remove instantiation of the stem and the special
# ability to receive as input a pair of hidden states. # ability to receive as input a pair of hidden states.
...@@ -163,11 +180,11 @@ class FasterRCNNNASFeatureExtractor( ...@@ -163,11 +180,11 @@ class FasterRCNNNASFeatureExtractor(
raise ValueError('`preprocessed_inputs` must be 4 dimensional, got a ' raise ValueError('`preprocessed_inputs` must be 4 dimensional, got a '
'tensor of shape %s' % preprocessed_inputs.get_shape()) 'tensor of shape %s' % preprocessed_inputs.get_shape())
with slim.arg_scope(nasnet.nasnet_large_arg_scope()): with slim.arg_scope(nasnet_large_arg_scope_for_detection(
is_batch_norm_training=self._train_batch_norm)):
_, end_points = nasnet.build_nasnet_large( _, end_points = nasnet.build_nasnet_large(
preprocessed_inputs, num_classes=None, preprocessed_inputs, num_classes=None,
is_training=self._is_training, is_training=self._is_training,
is_batchnorm_training=self._train_batch_norm,
final_endpoint='Cell_11') final_endpoint='Cell_11')
# Note that both 'Cell_10' and 'Cell_11' have equal depth = 2016. # Note that both 'Cell_10' and 'Cell_11' have equal depth = 2016.
......
...@@ -324,7 +324,7 @@ build_nasnet_cifar.default_image_size = 32 ...@@ -324,7 +324,7 @@ build_nasnet_cifar.default_image_size = 32
def build_nasnet_mobile(images, num_classes, def build_nasnet_mobile(images, num_classes,
is_training=True, is_batchnorm_training=True, is_training=True,
final_endpoint=None): final_endpoint=None):
"""Build NASNet Mobile model for the ImageNet Dataset.""" """Build NASNet Mobile model for the ImageNet Dataset."""
hparams = _mobile_imagenet_config() hparams = _mobile_imagenet_config()
...@@ -348,32 +348,31 @@ def build_nasnet_mobile(images, num_classes, ...@@ -348,32 +348,31 @@ def build_nasnet_mobile(images, num_classes,
reduction_cell = nasnet_utils.NasNetAReductionCell( reduction_cell = nasnet_utils.NasNetAReductionCell(
hparams.num_conv_filters, hparams.drop_path_keep_prob, hparams.num_conv_filters, hparams.drop_path_keep_prob,
total_num_cells, hparams.total_training_steps) total_num_cells, hparams.total_training_steps)
with arg_scope([slim.dropout, nasnet_utils.drop_path], with arg_scope([slim.dropout, nasnet_utils.drop_path, slim.batch_norm],
is_training=is_training): is_training=is_training):
with arg_scope([slim.batch_norm], is_training=is_batchnorm_training): with arg_scope([slim.avg_pool2d,
with arg_scope([slim.avg_pool2d, slim.max_pool2d,
slim.max_pool2d, slim.conv2d,
slim.conv2d, slim.batch_norm,
slim.batch_norm, slim.separable_conv2d,
slim.separable_conv2d, nasnet_utils.factorized_reduction,
nasnet_utils.factorized_reduction, nasnet_utils.global_avg_pool,
nasnet_utils.global_avg_pool, nasnet_utils.get_channel_index,
nasnet_utils.get_channel_index, nasnet_utils.get_channel_dim],
nasnet_utils.get_channel_dim], data_format=hparams.data_format):
data_format=hparams.data_format): return _build_nasnet_base(images,
return _build_nasnet_base(images, normal_cell=normal_cell,
normal_cell=normal_cell, reduction_cell=reduction_cell,
reduction_cell=reduction_cell, num_classes=num_classes,
num_classes=num_classes, hparams=hparams,
hparams=hparams, is_training=is_training,
is_training=is_training, stem_type='imagenet',
stem_type='imagenet', final_endpoint=final_endpoint)
final_endpoint=final_endpoint)
build_nasnet_mobile.default_image_size = 224 build_nasnet_mobile.default_image_size = 224
def build_nasnet_large(images, num_classes, def build_nasnet_large(images, num_classes,
is_training=True, is_batchnorm_training=True, is_training=True,
final_endpoint=None): final_endpoint=None):
"""Build NASNet Large model for the ImageNet Dataset.""" """Build NASNet Large model for the ImageNet Dataset."""
hparams = _large_imagenet_config(is_training=is_training) hparams = _large_imagenet_config(is_training=is_training)
...@@ -397,27 +396,26 @@ def build_nasnet_large(images, num_classes, ...@@ -397,27 +396,26 @@ def build_nasnet_large(images, num_classes,
reduction_cell = nasnet_utils.NasNetAReductionCell( reduction_cell = nasnet_utils.NasNetAReductionCell(
hparams.num_conv_filters, hparams.drop_path_keep_prob, hparams.num_conv_filters, hparams.drop_path_keep_prob,
total_num_cells, hparams.total_training_steps) total_num_cells, hparams.total_training_steps)
with arg_scope([slim.dropout, nasnet_utils.drop_path], with arg_scope([slim.dropout, nasnet_utils.drop_path, slim.batch_norm],
is_training=is_training): is_training=is_training):
with arg_scope([slim.batch_norm], is_training=is_batchnorm_training): with arg_scope([slim.avg_pool2d,
with arg_scope([slim.avg_pool2d, slim.max_pool2d,
slim.max_pool2d, slim.conv2d,
slim.conv2d, slim.batch_norm,
slim.batch_norm, slim.separable_conv2d,
slim.separable_conv2d, nasnet_utils.factorized_reduction,
nasnet_utils.factorized_reduction, nasnet_utils.global_avg_pool,
nasnet_utils.global_avg_pool, nasnet_utils.get_channel_index,
nasnet_utils.get_channel_index, nasnet_utils.get_channel_dim],
nasnet_utils.get_channel_dim], data_format=hparams.data_format):
data_format=hparams.data_format): return _build_nasnet_base(images,
return _build_nasnet_base(images, normal_cell=normal_cell,
normal_cell=normal_cell, reduction_cell=reduction_cell,
reduction_cell=reduction_cell, num_classes=num_classes,
num_classes=num_classes, hparams=hparams,
hparams=hparams, is_training=is_training,
is_training=is_training, stem_type='imagenet',
stem_type='imagenet', final_endpoint=final_endpoint)
final_endpoint=final_endpoint)
build_nasnet_large.default_image_size = 331 build_nasnet_large.default_image_size = 331
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment