Unverified Commit 3f5dbba9 authored by Jonathan Huang, committed by GitHub

Merge pull request #2685 from tensorflow/nas-fix-3

Fix nasnet image classification and object detection
parents b3f04bca cf8f014b
@@ -30,6 +30,23 @@ arg_scope = tf.contrib.framework.arg_scope
 slim = tf.contrib.slim
 
 
+def nasnet_large_arg_scope_for_detection(is_batch_norm_training=False):
+  """Defines the default arg scope for the NASNet-A Large for object detection.
+
+  This provides a small edit to switch batch norm training on and off.
+
+  Args:
+    is_batch_norm_training: Boolean indicating whether to train with batch norm.
+
+  Returns:
+    An `arg_scope` to use for the NASNet Large Model.
+  """
+  imagenet_scope = nasnet.nasnet_large_arg_scope()
+  with arg_scope(imagenet_scope):
+    with arg_scope([slim.batch_norm], is_training=is_batch_norm_training) as sc:
+      return sc
+
+
 # Note: This is largely a copy of _build_nasnet_base inside nasnet.py but
 # with special edits to remove instantiation of the stem and the special
 # ability to receive as input a pair of hidden states.
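For readers less familiar with `arg_scope`, the new helper relies on scope merging: the outer scope carries all of NASNet's ImageNet batch-norm defaults, and the inner scope overrides only `is_training`. A minimal sketch of that merge, using the same TF 1.x `tf.contrib` API as this diff (the `decay`/`epsilon` values below are stand-ins, not NASNet's real defaults):

import tensorflow as tf

slim = tf.contrib.slim
arg_scope = tf.contrib.framework.arg_scope

# Outer scope: pretend these are the ImageNet training defaults for batch norm.
with arg_scope([slim.batch_norm], decay=0.9997, epsilon=0.001):
  # Inner scope: override a single argument, keep the rest.
  with arg_scope([slim.batch_norm], is_training=False) as sc:
    pass

# `sc` now maps slim.batch_norm to
# {'decay': 0.9997, 'epsilon': 0.001, 'is_training': False}.
# nasnet_large_arg_scope_for_detection applies the same override on top of the
# real nasnet_large_arg_scope() defaults and returns the merged scope.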
@@ -163,11 +180,11 @@ class FasterRCNNNASFeatureExtractor(
       raise ValueError('`preprocessed_inputs` must be 4 dimensional, got a '
                        'tensor of shape %s' % preprocessed_inputs.get_shape())
 
-    with slim.arg_scope(nasnet.nasnet_large_arg_scope()):
+    with slim.arg_scope(nasnet_large_arg_scope_for_detection(
+        is_batch_norm_training=self._train_batch_norm)):
       _, end_points = nasnet.build_nasnet_large(
           preprocessed_inputs, num_classes=None,
           is_training=self._is_training,
-          is_batchnorm_training=self._train_batch_norm,
           final_endpoint='Cell_11')
 
     # Note that both 'Cell_10' and 'Cell_11' have equal depth = 2016.
...
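To see the new pieces end to end, here is a rough sketch (not code from this file) of building the backbone with frozen batch norm and combining the last two cell outputs, whose equal depth of 2016 is noted in the comment above. The import path assumes the TF-Slim `nets` package layout, and `nasnet_large_arg_scope_for_detection` is the helper added earlier in this diff, assumed to be in scope:

import tensorflow as tf
from nets.nasnet import nasnet

slim = tf.contrib.slim

# Hypothetical preprocessed input batch.
preprocessed = tf.placeholder(tf.float32, [1, 331, 331, 3])

with slim.arg_scope(nasnet_large_arg_scope_for_detection(
    is_batch_norm_training=False)):
  _, end_points = nasnet.build_nasnet_large(
      preprocessed, num_classes=None,
      is_training=False,
      final_endpoint='Cell_11')

# Both cells are 2016 channels deep, so the concatenation is 4032 channels deep.
rpn_features = tf.concat(
    [end_points['Cell_10'], end_points['Cell_11']], axis=3)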
@@ -324,7 +324,7 @@ build_nasnet_cifar.default_image_size = 32
 
 
 def build_nasnet_mobile(images, num_classes,
-                        is_training=True, is_batchnorm_training=True,
+                        is_training=True,
                         final_endpoint=None):
   """Build NASNet Mobile model for the ImageNet Dataset."""
   hparams = _mobile_imagenet_config()
@@ -348,32 +348,31 @@ def build_nasnet_mobile(images, num_classes,
   reduction_cell = nasnet_utils.NasNetAReductionCell(
       hparams.num_conv_filters, hparams.drop_path_keep_prob,
       total_num_cells, hparams.total_training_steps)
-  with arg_scope([slim.dropout, nasnet_utils.drop_path],
+  with arg_scope([slim.dropout, nasnet_utils.drop_path, slim.batch_norm],
                  is_training=is_training):
-    with arg_scope([slim.batch_norm], is_training=is_batchnorm_training):
-      with arg_scope([slim.avg_pool2d,
-                      slim.max_pool2d,
-                      slim.conv2d,
-                      slim.batch_norm,
-                      slim.separable_conv2d,
-                      nasnet_utils.factorized_reduction,
-                      nasnet_utils.global_avg_pool,
-                      nasnet_utils.get_channel_index,
-                      nasnet_utils.get_channel_dim],
-                     data_format=hparams.data_format):
-        return _build_nasnet_base(images,
-                                  normal_cell=normal_cell,
-                                  reduction_cell=reduction_cell,
-                                  num_classes=num_classes,
-                                  hparams=hparams,
-                                  is_training=is_training,
-                                  stem_type='imagenet',
-                                  final_endpoint=final_endpoint)
+    with arg_scope([slim.avg_pool2d,
+                    slim.max_pool2d,
+                    slim.conv2d,
+                    slim.batch_norm,
+                    slim.separable_conv2d,
+                    nasnet_utils.factorized_reduction,
+                    nasnet_utils.global_avg_pool,
+                    nasnet_utils.get_channel_index,
+                    nasnet_utils.get_channel_dim],
+                   data_format=hparams.data_format):
+      return _build_nasnet_base(images,
+                                normal_cell=normal_cell,
+                                reduction_cell=reduction_cell,
+                                num_classes=num_classes,
+                                hparams=hparams,
+                                is_training=is_training,
+                                stem_type='imagenet',
+                                final_endpoint=final_endpoint)
 build_nasnet_mobile.default_image_size = 224
 
 
 def build_nasnet_large(images, num_classes,
-                       is_training=True, is_batchnorm_training=True,
+                       is_training=True,
                        final_endpoint=None):
   """Build NASNet Large model for the ImageNet Dataset."""
   hparams = _large_imagenet_config(is_training=is_training)
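On the classification side, the practical effect is that callers now pass a single `is_training` flag and batch norm follows it, because `slim.batch_norm` sits in the same `arg_scope` as `slim.dropout` and `drop_path`. A rough eval-mode sketch (TF-Slim layout assumed; 1001 is the usual slim ImageNet class count, and the 224 input size matches `build_nasnet_mobile.default_image_size`):

import tensorflow as tf
from nets.nasnet import nasnet

slim = tf.contrib.slim

images = tf.placeholder(tf.float32, [None, 224, 224, 3])
with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
  # One flag now freezes dropout, drop_path and batch norm together.
  logits, end_points = nasnet.build_nasnet_mobile(
      images, num_classes=1001, is_training=False)
predictions = end_points['Predictions']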
@@ -397,27 +396,26 @@ def build_nasnet_large(images, num_classes,
   reduction_cell = nasnet_utils.NasNetAReductionCell(
       hparams.num_conv_filters, hparams.drop_path_keep_prob,
       total_num_cells, hparams.total_training_steps)
-  with arg_scope([slim.dropout, nasnet_utils.drop_path],
+  with arg_scope([slim.dropout, nasnet_utils.drop_path, slim.batch_norm],
                  is_training=is_training):
-    with arg_scope([slim.batch_norm], is_training=is_batchnorm_training):
-      with arg_scope([slim.avg_pool2d,
-                      slim.max_pool2d,
-                      slim.conv2d,
-                      slim.batch_norm,
-                      slim.separable_conv2d,
-                      nasnet_utils.factorized_reduction,
-                      nasnet_utils.global_avg_pool,
-                      nasnet_utils.get_channel_index,
-                      nasnet_utils.get_channel_dim],
-                     data_format=hparams.data_format):
-        return _build_nasnet_base(images,
-                                  normal_cell=normal_cell,
-                                  reduction_cell=reduction_cell,
-                                  num_classes=num_classes,
-                                  hparams=hparams,
-                                  is_training=is_training,
-                                  stem_type='imagenet',
-                                  final_endpoint=final_endpoint)
+    with arg_scope([slim.avg_pool2d,
+                    slim.max_pool2d,
+                    slim.conv2d,
+                    slim.batch_norm,
+                    slim.separable_conv2d,
+                    nasnet_utils.factorized_reduction,
+                    nasnet_utils.global_avg_pool,
+                    nasnet_utils.get_channel_index,
+                    nasnet_utils.get_channel_dim],
+                   data_format=hparams.data_format):
+      return _build_nasnet_base(images,
+                                normal_cell=normal_cell,
+                                reduction_cell=reduction_cell,
+                                num_classes=num_classes,
+                                hparams=hparams,
+                                is_training=is_training,
+                                stem_type='imagenet',
+                                final_endpoint=final_endpoint)
 build_nasnet_large.default_image_size = 331
...
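For anyone calling these builders directly, the migration is mechanical: drop the removed `is_batchnorm_training` keyword and let `is_training` drive batch norm as well; callers that need the two decoupled wrap the call in an outer `arg_scope`, as the detection change above does. A before/after sketch with placeholder arguments:

import tensorflow as tf
from nets.nasnet import nasnet

slim = tf.contrib.slim

images = tf.placeholder(tf.float32, [None, 331, 331, 3])

# Before this PR (argument no longer exists; shown only for comparison):
#   nasnet.build_nasnet_large(images, num_classes=1001,
#                             is_training=True, is_batchnorm_training=True)

# After this PR: a single flag controls dropout, drop_path and batch norm.
with slim.arg_scope(nasnet.nasnet_large_arg_scope()):
  logits, end_points = nasnet.build_nasnet_large(
      images, num_classes=1001, is_training=True)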