Commit 696b69a4 authored by Chenxi Liu, committed by Sergio Guadarrama

Internal changes including PNASNet-5 mobile (#4895)

* PiperOrigin-RevId: 201234832

* PiperOrigin-RevId: 202507333

* PiperOrigin-RevId: 204320344

* Add PNASNet-5 mobile network model and cell structure.

PiperOrigin-RevId: 204735410

* Add option to customize individual projection layer activation.

PiperOrigin-RevId: 204776951
parent 2d7a0d6a
@@ -269,6 +269,7 @@ Model | TF-Slim File | Checkpoint | Top-1 Accuracy| Top-5 Accuracy |
 [NASNet-A_Mobile_224](https://arxiv.org/abs/1707.07012)#|[Code](https://github.com/tensorflow/models/blob/master/research/slim/nets/nasnet/nasnet.py)|[nasnet-a_mobile_04_10_2017.tar.gz](https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_mobile_04_10_2017.tar.gz)|74.0|91.6|
 [NASNet-A_Large_331](https://arxiv.org/abs/1707.07012)#|[Code](https://github.com/tensorflow/models/blob/master/research/slim/nets/nasnet/nasnet.py)|[nasnet-a_large_04_10_2017.tar.gz](https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_large_04_10_2017.tar.gz)|82.7|96.2|
 [PNASNet-5_Large_331](https://arxiv.org/abs/1712.00559)|[Code](https://github.com/tensorflow/models/blob/master/research/slim/nets/nasnet/pnasnet.py)|[pnasnet-5_large_2017_12_13.tar.gz](https://storage.googleapis.com/download.tensorflow.org/models/pnasnet-5_large_2017_12_13.tar.gz)|82.9|96.2|
+[PNASNet-5_Mobile_224](https://arxiv.org/abs/1712.00559)|[Code](https://github.com/tensorflow/models/blob/master/research/slim/nets/nasnet/pnasnet.py)|[pnasnet-5_mobile_2017_12_13.tar.gz](https://storage.googleapis.com/download.tensorflow.org/models/pnasnet-5_mobile_2017_12_13.tar.gz)|74.2|91.9|
 
 ^ ResNet V2 models use Inception pre-processing and input image size of 299 (use
 `--preprocessing_name inception --eval_image_size 299` when using
...
@@ -365,10 +365,12 @@ def inception_resnet_v2(inputs, num_classes=1001, is_training=True,
 inception_resnet_v2.default_image_size = 299
 
 
-def inception_resnet_v2_arg_scope(weight_decay=0.00004,
-                                  batch_norm_decay=0.9997,
-                                  batch_norm_epsilon=0.001,
-                                  activation_fn=tf.nn.relu):
+def inception_resnet_v2_arg_scope(
+    weight_decay=0.00004,
+    batch_norm_decay=0.9997,
+    batch_norm_epsilon=0.001,
+    activation_fn=tf.nn.relu,
+    batch_norm_updates_collections=tf.GraphKeys.UPDATE_OPS):
   """Returns the scope with the default parameters for inception_resnet_v2.
 
   Args:
@@ -376,6 +378,8 @@ def inception_resnet_v2_arg_scope(weight_decay=0.00004,
     batch_norm_decay: decay for the moving average of batch_norm momentums.
     batch_norm_epsilon: small float added to variance to avoid dividing by zero.
     activation_fn: Activation function for conv2d.
+    batch_norm_updates_collections: Collection for the update ops for
+      batch norm.
 
   Returns:
     a arg_scope with the parameters needed for inception_resnet_v2.
@@ -388,6 +392,7 @@ def inception_resnet_v2_arg_scope(weight_decay=0.00004,
   batch_norm_params = {
       'decay': batch_norm_decay,
       'epsilon': batch_norm_epsilon,
+      'updates_collections': batch_norm_updates_collections,
       'fused': None,  # Use fused batch norm if possible.
   }
   # Set activation_fn and parameters for batch_norm.
...
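The new `batch_norm_updates_collections` argument is forwarded to `slim.batch_norm` as its `updates_collections` parameter; the same change is repeated for `inception_arg_scope`, `mobilenet_v1_arg_scope`, and `resnet_arg_scope` below. A minimal usage sketch, relying on the standard slim semantics that passing `None` folds the moving-average updates into the forward pass instead of collecting them in `tf.GraphKeys.UPDATE_OPS`:

import tensorflow as tf
from nets import inception_resnet_v2

slim = tf.contrib.slim

images = tf.placeholder(tf.float32, [None, 299, 299, 3])
# updates_collections=None updates the batch-norm moving averages in place,
# so the training loop no longer needs an explicit UPDATE_OPS dependency.
with slim.arg_scope(inception_resnet_v2.inception_resnet_v2_arg_scope(
    batch_norm_updates_collections=None)):
  logits, end_points = inception_resnet_v2.inception_resnet_v2(
      images, num_classes=1001, is_training=True)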
@@ -33,7 +33,8 @@ def inception_arg_scope(weight_decay=0.00004,
                         use_batch_norm=True,
                         batch_norm_decay=0.9997,
                         batch_norm_epsilon=0.001,
-                        activation_fn=tf.nn.relu):
+                        activation_fn=tf.nn.relu,
+                        batch_norm_updates_collections=tf.GraphKeys.UPDATE_OPS):
   """Defines the default arg scope for inception models.
 
   Args:
@@ -43,6 +44,8 @@ def inception_arg_scope(weight_decay=0.00004,
     batch_norm_epsilon: Small float added to variance to avoid dividing by zero
       in batch norm.
     activation_fn: Activation function for conv2d.
+    batch_norm_updates_collections: Collection for the update ops for
+      batch norm.
 
   Returns:
     An `arg_scope` to use for the inception models.
@@ -53,7 +56,7 @@ def inception_arg_scope(weight_decay=0.00004,
       # epsilon to prevent 0s in variance.
      'epsilon': batch_norm_epsilon,
       # collection containing update_ops.
-      'updates_collections': tf.GraphKeys.UPDATE_OPS,
+      'updates_collections': batch_norm_updates_collections,
       # use fused batch norm if possible.
       'fused': None,
   }
...
@@ -168,6 +168,7 @@ def expanded_conv(input_tensor,
                   kernel_size=(3, 3),
                   residual=True,
                   normalizer_fn=None,
+                  project_activation_fn=tf.identity,
                   split_projection=1,
                   split_expansion=1,
                   expansion_transform=None,
@@ -195,6 +196,7 @@ def expanded_conv(input_tensor,
     residual: whether to include residual connection between input
       and output.
     normalizer_fn: batchnorm or otherwise
+    project_activation_fn: activation function for the project layer
     split_projection: how many ways to split projection operator
       (that is conv expansion->bottleneck)
     split_expansion: how many ways to split expansion op
@@ -291,7 +293,7 @@ def expanded_conv(input_tensor,
         stride=1,
         scope='project',
         normalizer_fn=normalizer_fn,
-        activation_fn=tf.identity)
+        activation_fn=project_activation_fn)
     if endpoints is not None:
       endpoints['projection_output'] = net
     if depthwise_location == 'output':
...
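The projection layer's activation was previously hard-wired to `tf.identity` (a linear bottleneck); it is now a parameter with the same default, so existing callers are unaffected. A minimal sketch of overriding it; the input shape and `num_outputs` value are illustrative assumptions:

import tensorflow as tf
from nets.mobilenet import conv_blocks as ops

inputs = tf.placeholder(tf.float32, [None, 56, 56, 24])
# Apply relu6 after the projection 1x1 conv of this block only,
# instead of the default linear bottleneck.
net = ops.expanded_conv(inputs, num_outputs=24,
                        project_activation_fn=tf.nn.relu6)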
@@ -25,6 +25,7 @@ from __future__ import division
 from __future__ import print_function
 
 import copy
+import functools
 
 import tensorflow as tf
@@ -154,6 +155,22 @@ def mobilenet(input_tensor,
       **kwargs)
 
 
+def wrapped_partial(func, *args, **kwargs):
+  partial_func = functools.partial(func, *args, **kwargs)
+  functools.update_wrapper(partial_func, func)
+  return partial_func
+
+
+# Wrappers for mobilenet v2 with depth multipliers. Note that
+# 'finegrain_classification_mode' is set to True, which means the embedding
+# layer will not be shrunk when given a depth multiplier < 1.0.
+mobilenet_v2_140 = wrapped_partial(mobilenet, depth_multiplier=1.4)
+mobilenet_v2_050 = wrapped_partial(mobilenet, depth_multiplier=0.50,
+                                   finegrain_classification_mode=True)
+mobilenet_v2_035 = wrapped_partial(mobilenet, depth_multiplier=0.35,
+                                   finegrain_classification_mode=True)
+
+
 @slim.add_arg_scope
 def mobilenet_base(input_tensor, depth_multiplier=1.0, **kwargs):
   """Creates base of the mobilenet (no pooling and no logits) ."""
...
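Because `wrapped_partial` applies `functools.update_wrapper`, the partials keep `mobilenet`'s `__name__` and docstring, so they behave like ordinary model functions under slim's scoping machinery. A minimal usage sketch, assuming `num_classes=1001` (ImageNet plus background class):

import tensorflow as tf
from nets.mobilenet import mobilenet_v2

slim = tf.contrib.slim

images = tf.placeholder(tf.float32, [None, 224, 224, 3])
# 1.4x width multiplier; the 0.35/0.50 variants keep the final embedding
# at full width via finegrain_classification_mode=True.
with slim.arg_scope(mobilenet_v2.training_scope(is_training=False)):
  logits, end_points = mobilenet_v2.mobilenet_v2_140(images, num_classes=1001)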
@@ -425,12 +425,14 @@ def _reduced_kernel_size_for_small_input(input_tensor, kernel_size):
   return kernel_size_out
 
 
-def mobilenet_v1_arg_scope(is_training=True,
-                           weight_decay=0.00004,
-                           stddev=0.09,
-                           regularize_depthwise=False,
-                           batch_norm_decay=0.9997,
-                           batch_norm_epsilon=0.001):
+def mobilenet_v1_arg_scope(
+    is_training=True,
+    weight_decay=0.00004,
+    stddev=0.09,
+    regularize_depthwise=False,
+    batch_norm_decay=0.9997,
+    batch_norm_epsilon=0.001,
+    batch_norm_updates_collections=tf.GraphKeys.UPDATE_OPS):
   """Defines the default MobilenetV1 arg scope.
 
   Args:
@@ -442,6 +444,8 @@ def mobilenet_v1_arg_scope(is_training=True,
     batch_norm_decay: Decay for batch norm moving average.
     batch_norm_epsilon: Small float added to variance to avoid dividing by zero
       in batch norm.
+    batch_norm_updates_collections: Collection for the update ops for
+      batch norm.
 
   Returns:
     An `arg_scope` to use for the mobilenet v1 model.
@@ -451,6 +455,7 @@ def mobilenet_v1_arg_scope(is_training=True,
       'scale': True,
       'decay': batch_norm_decay,
       'epsilon': batch_norm_epsilon,
+      'updates_collections': batch_norm_updates_collections,
   }
   if is_training is not None:
     batch_norm_params['is_training'] = is_training
...
@@ -86,8 +86,6 @@ def global_avg_pool(x, data_format=INVALID):
 @tf.contrib.framework.add_arg_scope
 def factorized_reduction(net, output_filters, stride, data_format=INVALID):
   """Reduces the shape of net without information loss due to striding."""
-  assert output_filters % 2 == 0, (
-      'Need even number of filters when using this factorized reduction.')
   assert data_format != INVALID
   if stride == 1:
     net = slim.conv2d(net, output_filters, 1, scope='path_conv')
@@ -117,7 +115,10 @@ def factorized_reduction(net, output_filters, stride, data_format=INVALID):
   path2 = tf.nn.avg_pool(
       path2, [1, 1, 1, 1], stride_spec, 'VALID', data_format=data_format)
-  path2 = slim.conv2d(path2, int(output_filters / 2), 1, scope='path2_conv')
+
+  # If odd number of filters, add an additional one to the second path.
+  final_filter_size = int(output_filters / 2) + int(output_filters % 2)
+  path2 = slim.conv2d(path2, final_filter_size, 1, scope='path2_conv')
 
   # Concat and apply BN
   final_path = tf.concat(values=[path1, path2], axis=concat_axis)
...
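Why the assertion can go: the two pooling paths now split the channels explicitly, with any odd remainder assigned to the second path. A worked example in plain Python, assuming (from the unchanged part of `factorized_reduction`, not shown in this hunk) that path1 keeps `int(output_filters / 2)` channels:

output_filters = 135  # an odd channel count, as appears in the mobile model
path1_filters = int(output_filters / 2)                            # 67
path2_filters = int(output_filters / 2) + int(output_filters % 2)  # 68
assert path1_filters + path2_filters == output_filters  # 67 + 68 == 135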
@@ -43,6 +43,24 @@ def large_imagenet_config():
       use_aux_head=1,
       num_reduction_layers=2,
       data_format='NHWC',
+      skip_reduction_layer_input=1,
+      total_training_steps=250000,
+  )
+
+
+def mobile_imagenet_config():
+  """Mobile ImageNet configuration based on PNASNet-5."""
+  return tf.contrib.training.HParams(
+      stem_multiplier=1.0,
+      dense_dropout_keep_prob=0.5,
+      num_cells=9,
+      filter_scaling_rate=2.0,
+      num_conv_filters=54,
+      drop_path_keep_prob=1.0,
+      use_aux_head=1,
+      num_reduction_layers=2,
+      data_format='NHWC',
+      skip_reduction_layer_input=1,
       total_training_steps=250000,
   )
@@ -54,6 +72,14 @@ def pnasnet_large_arg_scope(weight_decay=4e-5, batch_norm_decay=0.9997,
       weight_decay, batch_norm_decay, batch_norm_epsilon)
 
 
+def pnasnet_mobile_arg_scope(weight_decay=4e-5,
+                             batch_norm_decay=0.9997,
+                             batch_norm_epsilon=0.001):
+  """Default arg scope for the PNASNet Mobile ImageNet model."""
+  return nasnet.nasnet_mobile_arg_scope(weight_decay, batch_norm_decay,
+                                        batch_norm_epsilon)
+
+
 def _build_pnasnet_base(images,
                         normal_cell,
                         num_classes,
@@ -92,7 +118,8 @@ def _build_pnasnet_base(images,
     is_reduction = cell_num in reduction_indices
     stride = 2 if is_reduction else 1
     if is_reduction: filter_scaling *= hparams.filter_scaling_rate
-    prev_layer = cell_outputs[-2]
+    if hparams.skip_reduction_layer_input or not is_reduction:
+      prev_layer = cell_outputs[-2]
     net = normal_cell(
         net,
         scope='cell_{}'.format(cell_num),
@@ -178,6 +205,56 @@ def build_pnasnet_large(images,
 build_pnasnet_large.default_image_size = 331
 
 
+def build_pnasnet_mobile(images,
+                         num_classes,
+                         is_training=True,
+                         final_endpoint=None,
+                         config=None):
+  """Build PNASNet Mobile model for the ImageNet Dataset."""
+  hparams = copy.deepcopy(config) if config else mobile_imagenet_config()
+  # pylint: disable=protected-access
+  nasnet._update_hparams(hparams, is_training)
+  # pylint: enable=protected-access
+
+  if tf.test.is_gpu_available() and hparams.data_format == 'NHWC':
+    tf.logging.info('A GPU is available on the machine, consider using NCHW '
+                    'data format for increased speed on GPU.')
+
+  if hparams.data_format == 'NCHW':
+    images = tf.transpose(images, [0, 3, 1, 2])
+
+  # Calculate the total number of cells in the network.
+  # There is no distinction between reduction and normal cells in PNAS, so the
+  # total number of cells is equal to the number of normal cells plus the
+  # number of stem cells (two by default).
+  total_num_cells = hparams.num_cells + 2
+
+  normal_cell = PNasNetNormalCell(hparams.num_conv_filters,
+                                  hparams.drop_path_keep_prob, total_num_cells,
+                                  hparams.total_training_steps)
+  with arg_scope(
+      [slim.dropout, nasnet_utils.drop_path, slim.batch_norm],
+      is_training=is_training):
+    with arg_scope(
+        [
+            slim.avg_pool2d, slim.max_pool2d, slim.conv2d, slim.batch_norm,
+            slim.separable_conv2d, nasnet_utils.factorized_reduction,
+            nasnet_utils.global_avg_pool, nasnet_utils.get_channel_index,
+            nasnet_utils.get_channel_dim
+        ],
+        data_format=hparams.data_format):
+      return _build_pnasnet_base(
+          images,
+          normal_cell=normal_cell,
+          num_classes=num_classes,
+          hparams=hparams,
+          is_training=is_training,
+          final_endpoint=final_endpoint)
+
+
+build_pnasnet_mobile.default_image_size = 224
+
+
 class PNasNetNormalCell(nasnet_utils.NasNetABaseCell):
   """PNASNet Normal Cell."""
...
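Taken together, the mobile model is built the same way as the large one, just at 224x224. A minimal inference-time sketch, assuming `num_classes=1001` for the released checkpoint; the explicit global step mirrors what the tests below do:

import tensorflow as tf
from nets.nasnet import pnasnet

slim = tf.contrib.slim

images = tf.placeholder(tf.float32, [None, 224, 224, 3])
tf.train.create_global_step()  # the cells read the global step, as in the tests

with slim.arg_scope(pnasnet.pnasnet_mobile_arg_scope()):
  logits, end_points = pnasnet.build_pnasnet_mobile(
      images, num_classes=1001, is_training=False)
probabilities = end_points['Predictions']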
@@ -43,6 +43,43 @@ class PNASNetTest(tf.test.TestCase):
     self.assertListEqual(predictions.get_shape().as_list(),
                          [batch_size, num_classes])
 
+  def testBuildLogitsMobileModel(self):
+    batch_size = 5
+    height, width = 224, 224
+    num_classes = 1000
+    inputs = tf.random_uniform((batch_size, height, width, 3))
+    tf.train.create_global_step()
+    with slim.arg_scope(pnasnet.pnasnet_mobile_arg_scope()):
+      logits, end_points = pnasnet.build_pnasnet_mobile(inputs, num_classes)
+    auxlogits = end_points['AuxLogits']
+    predictions = end_points['Predictions']
+    self.assertListEqual(auxlogits.get_shape().as_list(),
+                         [batch_size, num_classes])
+    self.assertListEqual(logits.get_shape().as_list(),
+                         [batch_size, num_classes])
+    self.assertListEqual(predictions.get_shape().as_list(),
+                         [batch_size, num_classes])
+
+  def testBuildNonExistingLayerLargeModel(self):
+    """Tests that the model is built correctly without unnecessary layers."""
+    inputs = tf.random_uniform((5, 331, 331, 3))
+    tf.train.create_global_step()
+    with slim.arg_scope(pnasnet.pnasnet_large_arg_scope()):
+      pnasnet.build_pnasnet_large(inputs, 1000)
+    vars_names = [x.op.name for x in tf.trainable_variables()]
+    self.assertIn('cell_stem_0/1x1/weights', vars_names)
+    self.assertNotIn('cell_stem_1/comb_iter_0/right/1x1/weights', vars_names)
+
+  def testBuildNonExistingLayerMobileModel(self):
+    """Tests that the model is built correctly without unnecessary layers."""
+    inputs = tf.random_uniform((5, 224, 224, 3))
+    tf.train.create_global_step()
+    with slim.arg_scope(pnasnet.pnasnet_mobile_arg_scope()):
+      pnasnet.build_pnasnet_mobile(inputs, 1000)
+    vars_names = [x.op.name for x in tf.trainable_variables()]
+    self.assertIn('cell_stem_0/1x1/weights', vars_names)
+    self.assertNotIn('cell_stem_1/comb_iter_0/right/1x1/weights', vars_names)
+
   def testBuildPreLogitsLargeModel(self):
     batch_size = 5
     height, width = 331, 331
@@ -56,6 +93,19 @@ class PNASNetTest(tf.test.TestCase):
     self.assertTrue(net.op.name.startswith('final_layer/Mean'))
     self.assertListEqual(net.get_shape().as_list(), [batch_size, 4320])
 
+  def testBuildPreLogitsMobileModel(self):
+    batch_size = 5
+    height, width = 224, 224
+    num_classes = None
+    inputs = tf.random_uniform((batch_size, height, width, 3))
+    tf.train.create_global_step()
+    with slim.arg_scope(pnasnet.pnasnet_mobile_arg_scope()):
+      net, end_points = pnasnet.build_pnasnet_mobile(inputs, num_classes)
+    self.assertFalse('AuxLogits' in end_points)
+    self.assertFalse('Predictions' in end_points)
+    self.assertTrue(net.op.name.startswith('final_layer/Mean'))
+    self.assertListEqual(net.get_shape().as_list(), [batch_size, 1080])
+
   def testAllEndPointsShapesLargeModel(self):
     batch_size = 5
     height, width = 331, 331
@@ -93,6 +143,41 @@ class PNASNetTest(tf.test.TestCase):
       self.assertListEqual(end_points[endpoint_name].get_shape().as_list(),
                            expected_shape)
 
+  def testAllEndPointsShapesMobileModel(self):
+    batch_size = 5
+    height, width = 224, 224
+    num_classes = 1000
+    inputs = tf.random_uniform((batch_size, height, width, 3))
+    tf.train.create_global_step()
+    with slim.arg_scope(pnasnet.pnasnet_mobile_arg_scope()):
+      _, end_points = pnasnet.build_pnasnet_mobile(inputs, num_classes)
+
+    endpoints_shapes = {
+        'Stem': [batch_size, 28, 28, 135],
+        'Cell_0': [batch_size, 28, 28, 270],
+        'Cell_1': [batch_size, 28, 28, 270],
+        'Cell_2': [batch_size, 28, 28, 270],
+        'Cell_3': [batch_size, 14, 14, 540],
+        'Cell_4': [batch_size, 14, 14, 540],
+        'Cell_5': [batch_size, 14, 14, 540],
+        'Cell_6': [batch_size, 7, 7, 1080],
+        'Cell_7': [batch_size, 7, 7, 1080],
+        'Cell_8': [batch_size, 7, 7, 1080],
+        'global_pool': [batch_size, 1080],
+        # Logits and predictions
+        'AuxLogits': [batch_size, num_classes],
+        'Predictions': [batch_size, num_classes],
+        'Logits': [batch_size, num_classes],
+    }
+    self.assertEqual(len(end_points), 14)
+    self.assertItemsEqual(endpoints_shapes.keys(), end_points.keys())
+    for endpoint_name in endpoints_shapes:
+      tf.logging.info('Endpoint name: {}'.format(endpoint_name))
+      expected_shape = endpoints_shapes[endpoint_name]
+      self.assertIn(endpoint_name, end_points)
+      self.assertListEqual(end_points[endpoint_name].get_shape().as_list(),
+                           expected_shape)
+
   def testNoAuxHeadLargeModel(self):
     batch_size = 5
     height, width = 331, 331
@@ -108,6 +193,21 @@ class PNASNetTest(tf.test.TestCase):
             config=config)
       self.assertEqual('AuxLogits' in end_points, use_aux_head)
 
+  def testNoAuxHeadMobileModel(self):
+    batch_size = 5
+    height, width = 224, 224
+    num_classes = 1000
+    for use_aux_head in (True, False):
+      tf.reset_default_graph()
+      inputs = tf.random_uniform((batch_size, height, width, 3))
+      tf.train.create_global_step()
+      config = pnasnet.mobile_imagenet_config()
+      config.set_hparam('use_aux_head', int(use_aux_head))
+      with slim.arg_scope(pnasnet.pnasnet_mobile_arg_scope()):
+        _, end_points = pnasnet.build_pnasnet_mobile(
+            inputs, num_classes, config=config)
+      self.assertEqual('AuxLogits' in end_points, use_aux_head)
+
   def testOverrideHParamsLargeModel(self):
     batch_size = 5
     height, width = 331, 331
@@ -122,6 +222,20 @@ class PNASNetTest(tf.test.TestCase):
     self.assertListEqual(
         end_points['Stem'].shape.as_list(), [batch_size, 540, 42, 42])
 
+  def testOverrideHParamsMobileModel(self):
+    batch_size = 5
+    height, width = 224, 224
+    num_classes = 1000
+    inputs = tf.random_uniform((batch_size, height, width, 3))
+    tf.train.create_global_step()
+    config = pnasnet.mobile_imagenet_config()
+    config.set_hparam('data_format', 'NCHW')
+    with slim.arg_scope(pnasnet.pnasnet_mobile_arg_scope()):
+      _, end_points = pnasnet.build_pnasnet_mobile(
+          inputs, num_classes, config=config)
+    self.assertListEqual(end_points['Stem'].shape.as_list(),
+                         [batch_size, 135, 28, 28])
+
 if __name__ == '__main__':
   tf.test.main()
...
@@ -61,10 +61,13 @@ networks_map = {'alexnet_v2': alexnet.alexnet_v2,
                 'mobilenet_v1_050': mobilenet_v1.mobilenet_v1_050,
                 'mobilenet_v1_025': mobilenet_v1.mobilenet_v1_025,
                 'mobilenet_v2': mobilenet_v2.mobilenet,
+                'mobilenet_v2_140': mobilenet_v2.mobilenet_v2_140,
+                'mobilenet_v2_035': mobilenet_v2.mobilenet_v2_035,
                 'nasnet_cifar': nasnet.build_nasnet_cifar,
                 'nasnet_mobile': nasnet.build_nasnet_mobile,
                 'nasnet_large': nasnet.build_nasnet_large,
                 'pnasnet_large': pnasnet.build_pnasnet_large,
+                'pnasnet_mobile': pnasnet.build_pnasnet_mobile,
                }
 
 arg_scopes_map = {'alexnet_v2': alexnet.alexnet_v2_arg_scope,
@@ -93,10 +96,13 @@ arg_scopes_map = {'alexnet_v2': alexnet.alexnet_v2_arg_scope,
                   'mobilenet_v1_050': mobilenet_v1.mobilenet_v1_arg_scope,
                   'mobilenet_v1_025': mobilenet_v1.mobilenet_v1_arg_scope,
                   'mobilenet_v2': mobilenet_v2.training_scope,
+                  'mobilenet_v2_035': mobilenet_v2.training_scope,
+                  'mobilenet_v2_140': mobilenet_v2.training_scope,
                   'nasnet_cifar': nasnet.nasnet_cifar_arg_scope,
                   'nasnet_mobile': nasnet.nasnet_mobile_arg_scope,
                   'nasnet_large': nasnet.nasnet_large_arg_scope,
                   'pnasnet_large': pnasnet.pnasnet_large_arg_scope,
+                  'pnasnet_mobile': pnasnet.pnasnet_mobile_arg_scope,
                  }
...
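With these registrations the new models are reachable through the factory like any other. A minimal sketch using `nets_factory.get_network_fn`, whose returned function carries the model's `default_image_size`:

import tensorflow as tf
from nets import nets_factory

tf.train.create_global_step()  # required by the (p)nasnet models
network_fn = nets_factory.get_network_fn(
    'pnasnet_mobile', num_classes=1001, is_training=False)
size = network_fn.default_image_size  # 224 for pnasnet_mobile
images = tf.placeholder(tf.float32, [None, size, size, 3])
logits, end_points = network_fn(images)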
@@ -224,7 +224,8 @@ def resnet_arg_scope(weight_decay=0.0001,
                      batch_norm_epsilon=1e-5,
                      batch_norm_scale=True,
                      activation_fn=tf.nn.relu,
-                     use_batch_norm=True):
+                     use_batch_norm=True,
+                     batch_norm_updates_collections=tf.GraphKeys.UPDATE_OPS):
   """Defines the default ResNet arg scope.
 
   TODO(gpapan): The batch-normalization related default values above are
@@ -242,6 +243,8 @@ def resnet_arg_scope(weight_decay=0.0001,
       activations in the batch normalization layer.
     activation_fn: The activation function which is used in ResNet.
     use_batch_norm: Whether or not to use batch normalization.
+    batch_norm_updates_collections: Collection for the update ops for
+      batch norm.
 
   Returns:
     An `arg_scope` to use for the resnet models.
@@ -250,7 +253,7 @@ def resnet_arg_scope(weight_decay=0.0001,
       'decay': batch_norm_decay,
       'epsilon': batch_norm_epsilon,
       'scale': batch_norm_scale,
-      'updates_collections': tf.GraphKeys.UPDATE_OPS,
+      'updates_collections': batch_norm_updates_collections,
       'fused': None,  # Use fused batch norm if possible.
   }
...
@@ -54,8 +54,12 @@ def get_preprocessing(name, is_training=False):
       'inception_resnet_v2': inception_preprocessing,
       'lenet': lenet_preprocessing,
       'mobilenet_v1': inception_preprocessing,
+      'mobilenet_v2': inception_preprocessing,
+      'mobilenet_v2_035': inception_preprocessing,
+      'mobilenet_v2_140': inception_preprocessing,
       'nasnet_mobile': inception_preprocessing,
       'nasnet_large': inception_preprocessing,
+      'pnasnet_mobile': inception_preprocessing,
       'pnasnet_large': inception_preprocessing,
       'resnet_v1_50': vgg_preprocessing,
       'resnet_v1_101': vgg_preprocessing,
...
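The matching input pipeline can be looked up the same way. A minimal sketch, where 224 matches `build_pnasnet_mobile.default_image_size`:

import tensorflow as tf
from preprocessing import preprocessing_factory

raw_image = tf.placeholder(tf.uint8, [None, None, 3])
# pnasnet_mobile maps to Inception preprocessing, per the table above.
image_preprocessing_fn = preprocessing_factory.get_preprocessing(
    'pnasnet_mobile', is_training=False)
image = image_preprocessing_fn(raw_image, 224, 224)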