Commit 8652f38d authored by maximneumann's avatar maximneumann Committed by Sergio Guadarrama
Browse files

Add PNASNet model. (#3707)

parent 34e79348
......@@ -205,6 +205,7 @@ py_library(
":nasnet",
":overfeat",
":pix2pix",
":pnasnet",
":resnet_v1",
":resnet_v2",
":vgg",
......@@ -533,6 +534,29 @@ py_test(
],
)
py_library(
name = "pnasnet",
srcs = ["nets/nasnet/pnasnet.py"],
srcs_version = "PY2AND3",
deps = [
":nasnet",
":nasnet_utils",
# "//tensorflow",
],
)
py_test(
name = "pnasnet_test",
size = "large",
srcs = ["nets/nasnet/pnasnet_test.py"],
shard_count = 4,
srcs_version = "PY2AND3",
deps = [
":pnasnet",
# "//tensorflow",
],
)
py_library(
name = "overfeat",
srcs = ["nets/overfeat.py"],
......
......@@ -262,6 +262,7 @@ Model | TF-Slim File | Checkpoint | Top-1 Accuracy| Top-5 Accuracy |
[MobileNet_v2_1.0_224^*](https://arxiv.org/abs/1801.04381)|[Code](https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py)|[Checkpoint TBA]()|72.2|91.0|
[NASNet-A_Mobile_224](https://arxiv.org/abs/1707.07012)#|[Code](https://github.com/tensorflow/models/blob/master/research/slim/nets/nasnet/nasnet.py)|[nasnet-a_mobile_04_10_2017.tar.gz](https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_mobile_04_10_2017.tar.gz)|74.0|91.6|
[NASNet-A_Large_331](https://arxiv.org/abs/1707.07012)#|[Code](https://github.com/tensorflow/models/blob/master/research/slim/nets/nasnet/nasnet.py)|[nasnet-a_large_04_10_2017.tar.gz](https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_large_04_10_2017.tar.gz)|82.7|96.2|
[PNASNet-5_Large_331](https://arxiv.org/abs/1712.00559)|[Code](https://github.com/tensorflow/models/blob/master/research/slim/nets/nasnet/pnasnet.py)|[pnasnet-5_large_2017_12_13.tar.gz](https://storage.googleapis.com/download.tensorflow.org/models/pnasnet-5_large_2017_12_13.tar.gz)|82.9|96.2|
^ ResNet V2 models use Inception pre-processing and input image size of 299 (use
`--preprocessing_name inception --eval_image_size 299` when using
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains the definition for the PNASNet classification networks.
Paper: https://arxiv.org/abs/1712.00559
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import tensorflow as tf
from nets.nasnet import nasnet
from nets.nasnet import nasnet_utils
arg_scope = tf.contrib.framework.arg_scope
slim = tf.contrib.slim
def large_imagenet_config():
"""Large ImageNet configuration based on PNASNet-5."""
return tf.contrib.training.HParams(
stem_multiplier=3.0,
dense_dropout_keep_prob=0.5,
num_cells=12,
filter_scaling_rate=2.0,
num_conv_filters=216,
drop_path_keep_prob=0.6,
use_aux_head=1,
num_reduction_layers=2,
data_format='NHWC',
total_training_steps=250000,
)
def pnasnet_large_arg_scope(weight_decay=4e-5, batch_norm_decay=0.9997,
batch_norm_epsilon=0.001):
"""Default arg scope for the PNASNet Large ImageNet model."""
return nasnet.nasnet_large_arg_scope(
weight_decay, batch_norm_decay, batch_norm_epsilon)
def _build_pnasnet_base(images,
normal_cell,
num_classes,
hparams,
is_training,
final_endpoint=None):
"""Constructs a PNASNet image model."""
end_points = {}
def add_and_check_endpoint(endpoint_name, net):
end_points[endpoint_name] = net
return final_endpoint and (endpoint_name == final_endpoint)
# Find where to place the reduction cells or stride normal cells
reduction_indices = nasnet_utils.calc_reduction_layers(
hparams.num_cells, hparams.num_reduction_layers)
# pylint: disable=protected-access
stem = lambda: nasnet._imagenet_stem(images, hparams, normal_cell)
# pylint: enable=protected-access
net, cell_outputs = stem()
if add_and_check_endpoint('Stem', net):
return net, end_points
# Setup for building in the auxiliary head.
aux_head_cell_idxes = []
if len(reduction_indices) >= 2:
aux_head_cell_idxes.append(reduction_indices[1] - 1)
# Run the cells
filter_scaling = 1.0
# true_cell_num accounts for the stem cells
true_cell_num = 2
for cell_num in range(hparams.num_cells):
is_reduction = cell_num in reduction_indices
stride = 2 if is_reduction else 1
if is_reduction: filter_scaling *= hparams.filter_scaling_rate
prev_layer = cell_outputs[-2]
net = normal_cell(
net,
scope='cell_{}'.format(cell_num),
filter_scaling=filter_scaling,
stride=stride,
prev_layer=prev_layer,
cell_num=true_cell_num)
if add_and_check_endpoint('Cell_{}'.format(cell_num), net):
return net, end_points
true_cell_num += 1
cell_outputs.append(net)
if (hparams.use_aux_head and cell_num in aux_head_cell_idxes and
num_classes and is_training):
aux_net = tf.nn.relu(net)
# pylint: disable=protected-access
nasnet._build_aux_head(aux_net, end_points, num_classes, hparams,
scope='aux_{}'.format(cell_num))
# pylint: enable=protected-access
# Final softmax layer
with tf.variable_scope('final_layer'):
net = tf.nn.relu(net)
net = nasnet_utils.global_avg_pool(net)
if add_and_check_endpoint('global_pool', net) or not num_classes:
return net, end_points
net = slim.dropout(net, hparams.dense_dropout_keep_prob, scope='dropout')
logits = slim.fully_connected(net, num_classes)
if add_and_check_endpoint('Logits', logits):
return net, end_points
predictions = tf.nn.softmax(logits, name='predictions')
if add_and_check_endpoint('Predictions', predictions):
return net, end_points
return logits, end_points
def build_pnasnet_large(images,
num_classes,
is_training=True,
final_endpoint=None,
config=None):
"""Build PNASNet Large model for the ImageNet Dataset."""
hparams = copy.deepcopy(config) if config else large_imagenet_config()
# pylint: disable=protected-access
nasnet._update_hparams(hparams, is_training)
# pylint: enable=protected-access
if tf.test.is_gpu_available() and hparams.data_format == 'NHWC':
tf.logging.info('A GPU is available on the machine, consider using NCHW '
'data format for increased speed on GPU.')
if hparams.data_format == 'NCHW':
images = tf.transpose(images, [0, 3, 1, 2])
# Calculate the total number of cells in the network.
# There is no distinction between reduction and normal cells in PNAS so the
# total number of cells is equal to the number normal cells plus the number
# of stem cells (two by default).
total_num_cells = hparams.num_cells + 2
normal_cell = PNasNetNormalCell(hparams.num_conv_filters,
hparams.drop_path_keep_prob, total_num_cells,
hparams.total_training_steps)
with arg_scope(
[slim.dropout, nasnet_utils.drop_path, slim.batch_norm],
is_training=is_training):
with arg_scope([slim.avg_pool2d, slim.max_pool2d, slim.conv2d,
slim.batch_norm, slim.separable_conv2d,
nasnet_utils.factorized_reduction,
nasnet_utils.global_avg_pool,
nasnet_utils.get_channel_index,
nasnet_utils.get_channel_dim],
data_format=hparams.data_format):
return _build_pnasnet_base(
images,
normal_cell=normal_cell,
num_classes=num_classes,
hparams=hparams,
is_training=is_training,
final_endpoint=final_endpoint)
build_pnasnet_large.default_image_size = 331
class PNasNetNormalCell(nasnet_utils.NasNetABaseCell):
"""PNASNet Normal Cell."""
def __init__(self, num_conv_filters, drop_path_keep_prob, total_num_cells,
total_training_steps):
# Configuration for the PNASNet-5 model.
operations = [
'separable_5x5_2', 'max_pool_3x3', 'separable_7x7_2', 'max_pool_3x3',
'separable_5x5_2', 'separable_3x3_2', 'separable_3x3_2', 'max_pool_3x3',
'separable_3x3_2', 'none'
]
used_hiddenstates = [1, 1, 0, 0, 0, 0, 0]
hiddenstate_indices = [1, 1, 0, 0, 0, 0, 4, 0, 1, 0]
super(PNasNetNormalCell, self).__init__(
num_conv_filters, operations, used_hiddenstates, hiddenstate_indices,
drop_path_keep_prob, total_num_cells, total_training_steps)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for slim.pnasnet."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from nets.nasnet import pnasnet
slim = tf.contrib.slim
class PNASNetTest(tf.test.TestCase):
def testBuildLogitsLargeModel(self):
batch_size = 5
height, width = 331, 331
num_classes = 1000
inputs = tf.random_uniform((batch_size, height, width, 3))
tf.train.create_global_step()
with slim.arg_scope(pnasnet.pnasnet_large_arg_scope()):
logits, end_points = pnasnet.build_pnasnet_large(inputs, num_classes)
auxlogits = end_points['AuxLogits']
predictions = end_points['Predictions']
self.assertListEqual(auxlogits.get_shape().as_list(),
[batch_size, num_classes])
self.assertListEqual(logits.get_shape().as_list(),
[batch_size, num_classes])
self.assertListEqual(predictions.get_shape().as_list(),
[batch_size, num_classes])
def testBuildPreLogitsLargeModel(self):
batch_size = 5
height, width = 331, 331
num_classes = None
inputs = tf.random_uniform((batch_size, height, width, 3))
tf.train.create_global_step()
with slim.arg_scope(pnasnet.pnasnet_large_arg_scope()):
net, end_points = pnasnet.build_pnasnet_large(inputs, num_classes)
self.assertFalse('AuxLogits' in end_points)
self.assertFalse('Predictions' in end_points)
self.assertTrue(net.op.name.startswith('final_layer/Mean'))
self.assertListEqual(net.get_shape().as_list(), [batch_size, 4320])
def testAllEndPointsShapesLargeModel(self):
batch_size = 5
height, width = 331, 331
num_classes = 1000
inputs = tf.random_uniform((batch_size, height, width, 3))
tf.train.create_global_step()
with slim.arg_scope(pnasnet.pnasnet_large_arg_scope()):
_, end_points = pnasnet.build_pnasnet_large(inputs, num_classes)
endpoints_shapes = {'Stem': [batch_size, 42, 42, 540],
'Cell_0': [batch_size, 42, 42, 1080],
'Cell_1': [batch_size, 42, 42, 1080],
'Cell_2': [batch_size, 42, 42, 1080],
'Cell_3': [batch_size, 42, 42, 1080],
'Cell_4': [batch_size, 21, 21, 2160],
'Cell_5': [batch_size, 21, 21, 2160],
'Cell_6': [batch_size, 21, 21, 2160],
'Cell_7': [batch_size, 21, 21, 2160],
'Cell_8': [batch_size, 11, 11, 4320],
'Cell_9': [batch_size, 11, 11, 4320],
'Cell_10': [batch_size, 11, 11, 4320],
'Cell_11': [batch_size, 11, 11, 4320],
'global_pool': [batch_size, 4320],
# Logits and predictions
'AuxLogits': [batch_size, 1000],
'Predictions': [batch_size, 1000],
'Logits': [batch_size, 1000],
}
self.assertEqual(len(end_points), 17)
self.assertItemsEqual(endpoints_shapes.keys(), end_points.keys())
for endpoint_name in endpoints_shapes:
tf.logging.info('Endpoint name: {}'.format(endpoint_name))
expected_shape = endpoints_shapes[endpoint_name]
self.assertIn(endpoint_name, end_points)
self.assertListEqual(end_points[endpoint_name].get_shape().as_list(),
expected_shape)
def testNoAuxHeadLargeModel(self):
batch_size = 5
height, width = 331, 331
num_classes = 1000
for use_aux_head in (True, False):
tf.reset_default_graph()
inputs = tf.random_uniform((batch_size, height, width, 3))
tf.train.create_global_step()
config = pnasnet.large_imagenet_config()
config.set_hparam('use_aux_head', int(use_aux_head))
with slim.arg_scope(pnasnet.pnasnet_large_arg_scope()):
_, end_points = pnasnet.build_pnasnet_large(inputs, num_classes,
config=config)
self.assertEqual('AuxLogits' in end_points, use_aux_head)
def testOverrideHParamsLargeModel(self):
batch_size = 5
height, width = 331, 331
num_classes = 1000
inputs = tf.random_uniform((batch_size, height, width, 3))
tf.train.create_global_step()
config = pnasnet.large_imagenet_config()
config.set_hparam('data_format', 'NCHW')
with slim.arg_scope(pnasnet.pnasnet_large_arg_scope()):
_, end_points = pnasnet.build_pnasnet_large(
inputs, num_classes, config=config)
self.assertListEqual(
end_points['Stem'].shape.as_list(), [batch_size, 540, 42, 42])
if __name__ == '__main__':
tf.test.main()
......@@ -32,6 +32,7 @@ from nets import resnet_v2
from nets import vgg
from nets.mobilenet import mobilenet_v2
from nets.nasnet import nasnet
from nets.nasnet import pnasnet
slim = tf.contrib.slim
......@@ -63,6 +64,7 @@ networks_map = {'alexnet_v2': alexnet.alexnet_v2,
'nasnet_cifar': nasnet.build_nasnet_cifar,
'nasnet_mobile': nasnet.build_nasnet_mobile,
'nasnet_large': nasnet.build_nasnet_large,
'pnasnet_large': pnasnet.build_pnasnet_large,
}
arg_scopes_map = {'alexnet_v2': alexnet.alexnet_v2_arg_scope,
......@@ -94,6 +96,7 @@ arg_scopes_map = {'alexnet_v2': alexnet.alexnet_v2_arg_scope,
'nasnet_cifar': nasnet.nasnet_cifar_arg_scope,
'nasnet_mobile': nasnet.nasnet_mobile_arg_scope,
'nasnet_large': nasnet.nasnet_large_arg_scope,
'pnasnet_large': pnasnet.pnasnet_large_arg_scope,
}
......
......@@ -56,6 +56,7 @@ def get_preprocessing(name, is_training=False):
'mobilenet_v1': inception_preprocessing,
'nasnet_mobile': inception_preprocessing,
'nasnet_large': inception_preprocessing,
'pnasnet_large': inception_preprocessing,
'resnet_v1_50': vgg_preprocessing,
'resnet_v1_101': vgg_preprocessing,
'resnet_v1_152': vgg_preprocessing,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment