Commit 2c962110 authored by huihui-personal, committed by aquariusjay

Open source deeplab changes. Check change log in README.md for detailed info. (#6231)

* Refactor deeplab to use MonitoredTrainingSession

PiperOrigin-RevId: 234237190

* Update export_model.py

* Update nas_cell.py

* Update nas_network.py

* Update train.py

* Update deeplab_demo.ipynb

* Update nas_cell.py
parent a432998c
@@ -124,40 +124,62 @@ tensorflow/models GitHub [issue
tracker](https://github.com/tensorflow/models/issues), prefixing the issue name
with "deeplab".

## License

All the code in the deeplab folder is covered by the [LICENSE](https://github.com/tensorflow/models/blob/master/LICENSE)
under tensorflow/models. Please refer to the LICENSE for details.

## Change Logs

### February 6, 2019
* Updated decoder module to exploit multiple low-level features with different
  output_strides.

### December 3, 2018
* Released the MobileNet-v2 checkpoint on ADE20K.

### November 19, 2018
* Supported NAS architecture for feature extraction. **Contributor**: Chenxi Liu.
* Supported hard pixel mining during training.

### October 1, 2018
* Released MobileNet-v2 depth-multiplier = 0.5 COCO-pretrained checkpoints on
  PASCAL VOC 2012, and Xception-65 COCO pretrained checkpoint (i.e., no PASCAL
  pretrained).

### September 5, 2018
* Released Cityscapes pretrained checkpoints trained with the best-found dense
  prediction cell.

### May 26, 2018
* Updated ADE20K pretrained checkpoint.

### May 18, 2018
* Added builders for ResNet-v1 and Xception model variants.
* Added ADE20K support, including colormap and pretrained Xception_65 checkpoint.
* Fixed a bug on using non-default depth_multiplier for MobileNet-v2.

### March 22, 2018
* Released checkpoints using MobileNet-V2 as network backbone and pretrained on
  PASCAL VOC 2012 and Cityscapes.

### March 5, 2018
* First release of DeepLab in TensorFlow including deeper Xception network
  backbone. Included checkpoints that have been pretrained on PASCAL VOC 2012
  and Cityscapes.
@@ -19,7 +19,6 @@ Common flags from train/eval/vis/export_model.py are collected in this script.
import collections
import copy
import json

import tensorflow as tf

flags = tf.app.flags
@@ -53,12 +52,16 @@ flags.DEFINE_multi_float('image_pyramid', None,
flags.DEFINE_boolean('add_image_level_feature', True,
                     'Add image level feature.')

flags.DEFINE_list(
    'image_pooling_crop_size', None,
    'Image pooling crop size [height, width] used in the ASPP module. When '
    'value is None, the model performs image pooling with "crop_size". This '
    'flag is useful when one wants to use different image pooling sizes.')

flags.DEFINE_list(
    'image_pooling_stride', '1,1',
    'Image pooling stride [height, width] used in the ASPP image pooling.')

flags.DEFINE_boolean('aspp_with_batch_norm', True,
                     'Use batch norm parameters for ASPP or not.')
@@ -74,11 +77,18 @@ flags.DEFINE_float('depth_multiplier', 1.0,
                   'Multiplier for the depth (number of channels) for all '
                   'convolution ops used in MobileNet.')

flags.DEFINE_integer('divisible_by', None,
                     'An integer that ensures the number of layer channels is '
                     'divisible by this value. Used in MobileNet.')
# For `xception_65`, use decoder_output_stride = 4. For `mobilenet_v2`, use
# decoder_output_stride = None.
flags.DEFINE_list('decoder_output_stride', None,
                  'Comma-separated list of strings with the number specifying '
                  'output stride of low-level features at each network level. '
                  'The current semantic segmentation implementation assumes '
                  'at most one output stride (i.e., either None or a list '
                  'with only one element).')
flags.DEFINE_boolean('decoder_use_separable_conv', True,
                     'Employ separable convolution for decoder or not.')
@@ -86,11 +96,28 @@ flags.DEFINE_boolean('decoder_use_separable_conv', True,
flags.DEFINE_enum('merge_method', 'max', ['max', 'avg'],
                  'Scheme to merge multi scale features.')
flags.DEFINE_boolean(
    'prediction_with_upsampled_logits', True,
    'When performing prediction, there are two options: (1) bilinear '
    'upsampling the logits followed by argmax, or (2) argmax followed by '
    'nearest upsampling the predicted labels. The second option may introduce '
    'some "blocking effect", but it is more computationally efficient. '
    'Currently, prediction_with_upsampled_logits=False is only supported for '
    'single-scale inference.')
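
# A minimal sketch (not part of this file) of the two prediction paths the
# help string above describes; `logits` and `crop_size` are assumed inputs:
#
#   if FLAGS.prediction_with_upsampled_logits:
#     # (1) Bilinear-upsample the logits, then argmax: smoother boundaries.
#     logits = tf.image.resize_bilinear(logits, crop_size, align_corners=True)
#     predictions = tf.argmax(logits, 3)
#   else:
#     # (2) Argmax at feature resolution, then nearest-neighbor upsample the
#     # integer labels: cheaper, but may show the "blocking effect".
#     predictions = tf.argmax(logits, 3)
#     predictions = tf.squeeze(tf.image.resize_nearest_neighbor(
#         tf.expand_dims(tf.cast(predictions, tf.int32), 3), crop_size,
#         align_corners=True), 3)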
flags.DEFINE_string(
    'dense_prediction_cell_json',
    '',
    'A JSON file that specifies the dense prediction cell.')
flags.DEFINE_integer(
    'nas_stem_output_num_conv_filters', 20,
    'Number of filters of the stem output tensor in NAS models.')

flags.DEFINE_bool('use_bounded_activation', False,
                  'Whether or not to use bounded activations. Bounded '
                  'activations better lend themselves to quantized inference.')
FLAGS = flags.FLAGS

# Constants
@@ -117,9 +144,11 @@ class ModelOptions(
        'crop_size',
        'atrous_rates',
        'output_stride',
        'preprocessed_images_dtype',
        'merge_method',
        'add_image_level_feature',
        'image_pooling_crop_size',
        'image_pooling_stride',
        'aspp_with_batch_norm',
        'aspp_with_separable_conv',
        'multi_grid',
@@ -128,7 +157,11 @@ class ModelOptions(
        'logits_kernel_size',
        'model_variant',
        'depth_multiplier',
        'divisible_by',
        'prediction_with_upsampled_logits',
        'dense_prediction_cell_config',
        'nas_stem_output_num_conv_filters',
        'use_bounded_activation'
    ])):
  """Immutable class to hold model options."""
@@ -138,7 +171,8 @@ class ModelOptions(
              outputs_to_num_classes,
              crop_size=None,
              atrous_rates=None,
              output_stride=8,
              preprocessed_images_dtype=tf.float32):
    """Constructor to set default values.

    Args:
@@ -148,6 +182,7 @@ class ModelOptions(
      crop_size: A tuple [crop_height, crop_width].
      atrous_rates: A list of atrous convolution rates for ASPP.
      output_stride: The ratio of input to output spatial resolution.
      preprocessed_images_dtype: The type after the preprocessing function.

    Returns:
      A new ModelOptions instance.
@@ -156,18 +191,35 @@ class ModelOptions(
    if FLAGS.dense_prediction_cell_json:
      with tf.gfile.Open(FLAGS.dense_prediction_cell_json, 'r') as f:
        dense_prediction_cell_config = json.load(f)
    decoder_output_stride = None
    if FLAGS.decoder_output_stride:
      decoder_output_stride = [
          int(x) for x in FLAGS.decoder_output_stride]
      if sorted(decoder_output_stride, reverse=True) != decoder_output_stride:
        raise ValueError('Decoder output stride needs to be sorted in '
                         'descending order.')
    image_pooling_crop_size = None
    if FLAGS.image_pooling_crop_size:
      image_pooling_crop_size = [int(x) for x in FLAGS.image_pooling_crop_size]
    image_pooling_stride = [1, 1]
    if FLAGS.image_pooling_stride:
      image_pooling_stride = [int(x) for x in FLAGS.image_pooling_stride]
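    # Worked example of the parsing above (hypothetical flag values):
    #   --decoder_output_stride=8,4 arrives as ['8', '4'] and becomes [8, 4];
    #   --decoder_output_stride=4,8 raises ValueError (not descending);
    #   --image_pooling_crop_size=513,513 becomes [513, 513].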
    return super(ModelOptions, cls).__new__(
        cls, outputs_to_num_classes, crop_size, atrous_rates, output_stride,
        preprocessed_images_dtype, FLAGS.merge_method,
        FLAGS.add_image_level_feature,
        image_pooling_crop_size,
        image_pooling_stride,
        FLAGS.aspp_with_batch_norm,
        FLAGS.aspp_with_separable_conv, FLAGS.multi_grid, decoder_output_stride,
        FLAGS.decoder_use_separable_conv, FLAGS.logits_kernel_size,
        FLAGS.model_variant, FLAGS.depth_multiplier, FLAGS.divisible_by,
        FLAGS.prediction_with_upsampled_logits, dense_prediction_cell_config,
        FLAGS.nas_stem_output_num_conv_filters, FLAGS.use_bounded_activation)
  def __deepcopy__(self, memo):
    return ModelOptions(copy.deepcopy(self.outputs_to_num_classes),
                        self.crop_size,
                        self.atrous_rates,
                        self.output_stride,
                        self.preprocessed_images_dtype)
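
# Usage sketch (values illustrative; 'semantic' is the conventional DeepLab
# output name):
#
#   model_options = ModelOptions(
#       outputs_to_num_classes={'semantic': 21},
#       crop_size=[513, 513],
#       atrous_rates=[6, 12, 18],
#       output_stride=16,
#       preprocessed_images_dtype=tf.float32)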
@@ -64,12 +64,12 @@ def dense_prediction_cell_hparams():
      to 8, we need to double the convolution rates correspondingly.
  """
  return {
      'reduction_size': 256,
      'dropout_on_concat_features': True,
      'dropout_on_projection_features': False,
      'dropout_keep_prob': 0.9,
      'concat_channels': 256,
      'conv_rate_multiplier': 1,
  }
@@ -127,7 +127,7 @@ class DensePredictionCell(object):
    return ([resize_height, resize_width], [pooled_height, pooled_width])

  def _parse_operation(self, config, crop_size, output_stride,
                       image_pooling_crop_size=None):
    """Parses one operation.

    When 'operation' is 'pyramid_pooling', we compute the required
@@ -150,23 +150,23 @@ class DensePredictionCell(object):
    if config[_OP] == _PYRAMID_POOLING:
      (config[_TARGET_SIZE],
       config[_KERNEL]) = self._get_pyramid_pooling_arguments(
           crop_size=crop_size,
           output_stride=output_stride,
           image_grid=config[_GRID_SIZE],
           image_pooling_crop_size=image_pooling_crop_size)

    return config
  def build_cell(self,
                 features,
                 output_stride=16,
                 crop_size=None,
                 image_pooling_crop_size=None,
                 weight_decay=0.00004,
                 reuse=None,
                 is_training=False,
                 fine_tune_batch_norm=False,
                 scope=None):
    """Builds the dense prediction cell based on the config.

    Args:
@@ -194,10 +194,10 @@ class DensePredictionCell(object):
        the operation is not recognized.
    """
    batch_norm_params = {
        'is_training': is_training and fine_tune_batch_norm,
        'decay': 0.9997,
        'epsilon': 1e-5,
        'scale': True,
    }
    hparams = self.hparams
    with slim.arg_scope(
@@ -226,7 +226,7 @@ class DensePredictionCell(object):
          operation_input = branch_logits[current_config[_INPUT]]
          if current_config[_OP] == _CONV:
            if current_config[_KERNEL] == [1, 1] or current_config[
                _KERNEL] == 1:
              branch_logits.append(
                  slim.conv2d(operation_input, depth, 1, scope=scope))
            else:
@@ -285,4 +285,4 @@ class DensePredictionCell(object):
              keep_prob=self.hparams['dropout_keep_prob'],
              is_training=is_training,
              scope=_CONCAT_PROJECTION_SCOPE + '_dropout')
      return concat_logits
@@ -29,49 +29,49 @@ class DensePredictionCellTest(tf.test.TestCase):
  def setUp(self):
    self.segmentation_layer = dense_prediction_cell.DensePredictionCell(
        config=[
            {
                dense_prediction_cell._INPUT: -1,
                dense_prediction_cell._OP: dense_prediction_cell._CONV,
                dense_prediction_cell._KERNEL: 1,
            },
            {
                dense_prediction_cell._INPUT: 0,
                dense_prediction_cell._OP: dense_prediction_cell._CONV,
                dense_prediction_cell._KERNEL: 3,
                dense_prediction_cell._RATE: [1, 3],
            },
            {
                dense_prediction_cell._INPUT: 1,
                dense_prediction_cell._OP: (
                    dense_prediction_cell._PYRAMID_POOLING),
                dense_prediction_cell._GRID_SIZE: [1, 2],
            },
        ],
        hparams={'conv_rate_multiplier': 2})
  def testPyramidPoolingArguments(self):
    features_size, pooled_kernel = (
        self.segmentation_layer._get_pyramid_pooling_arguments(
            crop_size=[513, 513],
            output_stride=16,
            image_grid=[4, 4]))
    self.assertListEqual(features_size, [33, 33])
    self.assertListEqual(pooled_kernel, [9, 9])

  def testPyramidPoolingArgumentsWithImageGrid1x1(self):
    features_size, pooled_kernel = (
        self.segmentation_layer._get_pyramid_pooling_arguments(
            crop_size=[257, 257],
            output_stride=16,
            image_grid=[1, 1]))
    self.assertListEqual(features_size, [17, 17])
    self.assertListEqual(pooled_kernel, [17, 17])
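
  # The expected values above are consistent with scale_dimension: a 513x513
  # crop at output_stride 16 gives a (513 - 1) / 16 + 1 = 33 feature map,
  # pooled with kernel (33 - 1) // 4 + 1 = 9 for a 4x4 grid; likewise
  # 257 -> 17, with a 1x1 grid pooling over the whole 17x17 map.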
  def testParseOperationStringWithConv1x1(self):
    operation = self.segmentation_layer._parse_operation(
        config={
            dense_prediction_cell._OP: dense_prediction_cell._CONV,
            dense_prediction_cell._KERNEL: [1, 1],
        },
        crop_size=[513, 513], output_stride=16)
    self.assertEqual(operation[dense_prediction_cell._OP],
@@ -81,9 +81,9 @@ class DensePredictionCellTest(tf.test.TestCase):
  def testParseOperationStringWithConv3x3(self):
    operation = self.segmentation_layer._parse_operation(
        config={
            dense_prediction_cell._OP: dense_prediction_cell._CONV,
            dense_prediction_cell._KERNEL: [3, 3],
            dense_prediction_cell._RATE: [9, 6],
        },
        crop_size=[513, 513], output_stride=16)
    self.assertEqual(operation[dense_prediction_cell._OP],
@@ -94,8 +94,8 @@ class DensePredictionCellTest(tf.test.TestCase):
  def testParseOperationStringWithPyramidPooling2x2(self):
    operation = self.segmentation_layer._parse_operation(
        config={
            dense_prediction_cell._OP: dense_prediction_cell._PYRAMID_POOLING,
            dense_prediction_cell._GRID_SIZE: [2, 2],
        },
        crop_size=[513, 513],
        output_stride=16)
@@ -132,4 +132,4 @@ class DensePredictionCellTest(tf.test.TestCase):
if __name__ == '__main__':
  tf.test.main()
@@ -17,6 +17,7 @@
import functools
import tensorflow as tf

from deeplab.core import nas_network
from deeplab.core import resnet_v1_beta
from deeplab.core import xception
from tensorflow.contrib.slim.nets import resnet_utils
@@ -32,6 +33,7 @@ _MOBILENET_V2_FINAL_ENDPOINT = 'layer_18'
def _mobilenet_v2(net,
                  depth_multiplier,
                  output_stride,
                  divisible_by=None,
                  reuse=None,
                  scope=None,
                  final_endpoint=None):
@@ -48,6 +50,8 @@ def _mobilenet_v2(net,
      if necessary to prevent the network from reducing the spatial resolution
      of the activation maps. Allowed values are 8 (accurate fully convolutional
      mode), 16 (fast fully convolutional mode), 32 (classification mode).
    divisible_by: None (use default setting) or an integer that ensures the
      number of channels in all layers is divisible by this number. Used in
      MobileNet.
    reuse: Reuse model variables.
    scope: Optional variable scope.
    final_endpoint: The endpoint to construct the network up to.
@@ -55,6 +59,8 @@ def _mobilenet_v2(net,
  Returns:
    Features extracted by MobileNetv2.
  """
  if divisible_by is None:
    divisible_by = 8 if depth_multiplier == 1.0 else 1
  with tf.variable_scope(
      scope, 'MobilenetV2', [net], reuse=reuse) as scope:
    return mobilenet_v2.mobilenet_base(
@@ -62,7 +68,7 @@ def _mobilenet_v2(net,
        conv_defs=mobilenet_v2.V2_DEF,
        depth_multiplier=depth_multiplier,
        min_depth=8 if depth_multiplier == 1.0 else 1,
        divisible_by=divisible_by,
        final_endpoint=final_endpoint or _MOBILENET_V2_FINAL_ENDPOINT,
        output_stride=output_stride,
        scope=scope)
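
# A sketch of the channel-rounding rule that `divisible_by` controls (slim's
# mobilenet applies an equivalent helper internally; this version is
# illustrative, not the library code):
#
#   def _make_divisible(value, divisor, min_value=None):
#     if min_value is None:
#       min_value = divisor
#     new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
#     if new_value < 0.9 * value:  # Never round down by more than 10%.
#       new_value += divisor
#     return new_value
#
# E.g. with depth_multiplier=0.5, a layer requesting 12 channels rounds up to
# 16 when divisible_by=8, while a 48-channel request stays at 48.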
@@ -78,6 +84,8 @@ networks_map = {
    'xception_41': xception.xception_41,
    'xception_65': xception.xception_65,
    'xception_71': xception.xception_71,
    'nas_pnasnet': nas_network.pnasnet,
    'nas_hnasnet': nas_network.hnasnet,
}
# A map from network name to network arg scope.
@@ -90,6 +98,8 @@ arg_scopes_map = {
    'xception_41': xception.xception_arg_scope,
    'xception_65': xception.xception_arg_scope,
    'xception_71': xception.xception_arg_scope,
    'nas_pnasnet': nas_network.nas_arg_scope,
    'nas_hnasnet': nas_network.nas_arg_scope,
}
# Names for end point features.
@@ -98,37 +108,57 @@ DECODER_END_POINTS = 'decoder_end_points'
# A dictionary from network name to a map of end point features.
networks_to_feature_maps = {
    'mobilenet_v2': {
        DECODER_END_POINTS: {
            4: ['layer_4/depthwise_output'],
        },
    },
    'resnet_v1_50': {
        DECODER_END_POINTS: {
            4: ['block1/unit_2/bottleneck_v1/conv3'],
        },
    },
    'resnet_v1_50_beta': {
        DECODER_END_POINTS: {
            4: ['block1/unit_2/bottleneck_v1/conv3'],
        },
    },
    'resnet_v1_101': {
        DECODER_END_POINTS: {
            4: ['block1/unit_2/bottleneck_v1/conv3'],
        },
    },
    'resnet_v1_101_beta': {
        DECODER_END_POINTS: {
            4: ['block1/unit_2/bottleneck_v1/conv3'],
        },
    },
    'xception_41': {
        DECODER_END_POINTS: {
            4: ['entry_flow/block2/unit_1/xception_module/'
                'separable_conv2_pointwise'],
        },
    },
    'xception_65': {
        DECODER_END_POINTS: {
            4: ['entry_flow/block2/unit_1/xception_module/'
                'separable_conv2_pointwise'],
        },
    },
    'xception_71': {
        DECODER_END_POINTS: {
            4: ['entry_flow/block3/unit_1/xception_module/'
                'separable_conv2_pointwise'],
        },
    },
    'nas_pnasnet': {
        DECODER_END_POINTS: {
            4: ['Stem'],
        },
    },
    'nas_hnasnet': {
        DECODER_END_POINTS: {
            4: ['Cell_2'],
        },
    },
}
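
# Lookup sketch for the new nested layout, which keys decoder end points by
# output stride, e.g.:
#
#   networks_to_feature_maps['xception_65'][DECODER_END_POINTS][4]
#   # -> ['entry_flow/block2/unit_1/xception_module/separable_conv2_pointwise']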
@@ -143,21 +173,28 @@ name_scope = {
    'xception_41': 'xception_41',
    'xception_65': 'xception_65',
    'xception_71': 'xception_71',
    'nas_pnasnet': 'pnasnet',
    'nas_hnasnet': 'hnasnet',
}
# Mean pixel value.
_MEAN_RGB = [123.15, 115.90, 103.06]
def _preprocess_subtract_imagenet_mean(inputs, dtype=tf.float32):
  """Subtract Imagenet mean RGB value."""
  mean_rgb = tf.reshape(_MEAN_RGB, [1, 1, 1, 3])
  num_channels = tf.shape(inputs)[-1]
  # We set mean pixel as 0 for the non-RGB channels.
  mean_rgb_extended = tf.concat(
      [mean_rgb, tf.zeros([1, 1, 1, num_channels - 3])], axis=3)
  return tf.cast(inputs - mean_rgb_extended, dtype=dtype)
def _preprocess_zero_mean_unit_range(inputs, dtype=tf.float32):
  """Map image values from [0, 255] to [-1, 1]."""
  preprocessed_inputs = (2.0 / 255.0) * tf.to_float(inputs) - 1.0
  return tf.cast(preprocessed_inputs, dtype=dtype)
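
# E.g. pixel value 0 maps to -1.0, 127.5 to 0.0, and 255 to 1.0; the final
# cast lets callers keep preprocessed images in a reduced-precision dtype
# such as tf.float16.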
_PREPROCESS_FN = {
@@ -169,6 +206,8 @@ _PREPROCESS_FN = {
    'xception_41': _preprocess_zero_mean_unit_range,
    'xception_65': _preprocess_zero_mean_unit_range,
    'xception_71': _preprocess_zero_mean_unit_range,
    'nas_pnasnet': _preprocess_zero_mean_unit_range,
    'nas_hnasnet': _preprocess_zero_mean_unit_range,
}
@@ -201,6 +240,7 @@ def extract_features(images,
                     output_stride=8,
                     multi_grid=None,
                     depth_multiplier=1.0,
                     divisible_by=None,
                     final_endpoint=None,
                     model_variant=None,
                     weight_decay=0.0001,
@@ -209,8 +249,12 @@ def extract_features(images,
                     fine_tune_batch_norm=False,
                     regularize_depthwise=False,
                     preprocess_images=True,
                     preprocessed_images_dtype=tf.float32,
                     num_classes=None,
                     global_pool=False,
                     nas_stem_output_num_conv_filters=20,
                     nas_training_hyper_parameters=None,
                     use_bounded_activation=False):
  """Extracts features by the particular model_variant.

  Args:
@@ -219,6 +263,8 @@ def extract_features(images,
    multi_grid: Employ a hierarchy of different atrous rates within network.
    depth_multiplier: Float multiplier for the depth (number of channels)
      for all convolution ops used in MobileNet.
    divisible_by: None (use default setting) or an integer that ensures the
      number of channels in all layers is divisible by this number. Used in
      MobileNet.
    final_endpoint: The MobileNet endpoint to construct the network up to.
    model_variant: Model variant for feature extraction.
    weight_decay: The weight decay for model variables.
@@ -231,10 +277,22 @@ def extract_features(images,
      True. Set to False if preprocessing will be done by other functions. We
      support two types of preprocessing: (1) Mean pixel subtraction and (2)
      Pixel values normalization to be [-1, 1].
    preprocessed_images_dtype: The type after the preprocessing function.
    num_classes: Number of classes for image classification task. Defaults
      to None for dense prediction tasks.
    global_pool: Global pooling for image classification task. Defaults to
      False, since dense prediction tasks do not use this.
    nas_stem_output_num_conv_filters: Number of filters of the NAS stem output
      tensor.
    nas_training_hyper_parameters: A dictionary storing hyper-parameters for
      training nas models. It is either None or its keys are:
      - `drop_path_keep_prob`: Probability to keep each path in the cell when
        training.
      - `total_training_steps`: Total training steps to help drop path
        probability calculation.
    use_bounded_activation: Whether or not to use bounded activations. Bounded
      activations better lend themselves to quantized inference. Currently,
      bounded activation is only used in xception model.
  Returns:
    features: A tensor of size [batch, feature_height, feature_width,
@@ -253,7 +311,7 @@ def extract_features(images,
        batch_norm_epsilon=1e-5,
        batch_norm_scale=True)
    features, end_points = get_network(
        model_variant, preprocess_images, preprocessed_images_dtype, arg_scope)(
            inputs=images,
            num_classes=num_classes,
            is_training=(is_training and fine_tune_batch_norm),
@@ -268,9 +326,10 @@ def extract_features(images,
        batch_norm_decay=0.9997,
        batch_norm_epsilon=1e-3,
        batch_norm_scale=True,
        regularize_depthwise=regularize_depthwise,
        use_bounded_activation=use_bounded_activation)
    features, end_points = get_network(
        model_variant, preprocess_images, preprocessed_images_dtype, arg_scope)(
            inputs=images,
            num_classes=num_classes,
            is_training=(is_training and fine_tune_batch_norm),
@@ -285,25 +344,44 @@ def extract_features(images,
        is_training=(is_training and fine_tune_batch_norm),
        weight_decay=weight_decay)
    features, end_points = get_network(
        model_variant, preprocess_images, preprocessed_images_dtype, arg_scope)(
            inputs=images,
            depth_multiplier=depth_multiplier,
            divisible_by=divisible_by,
            output_stride=output_stride,
            reuse=reuse,
            scope=name_scope[model_variant],
            final_endpoint=final_endpoint)
  elif model_variant.startswith('nas'):
    arg_scope = arg_scopes_map[model_variant](
        weight_decay=weight_decay,
        batch_norm_decay=0.9997,
        batch_norm_epsilon=1e-3)
    features, end_points = get_network(
        model_variant, preprocess_images, preprocessed_images_dtype, arg_scope)(
            inputs=images,
            num_classes=num_classes,
            is_training=(is_training and fine_tune_batch_norm),
            global_pool=global_pool,
            output_stride=output_stride,
            nas_stem_output_num_conv_filters=nas_stem_output_num_conv_filters,
            nas_training_hyper_parameters=nas_training_hyper_parameters,
            reuse=reuse,
            scope=name_scope[model_variant])
  else:
    raise ValueError('Unknown model variant %s.' % model_variant)

  return features, end_points
def get_network(network_name, preprocess_images,
                preprocessed_images_dtype=tf.float32, arg_scope=None):
  """Gets the network.

  Args:
    network_name: Network name.
    preprocess_images: Preprocesses the images or not.
    preprocessed_images_dtype: The type after the preprocessing function.
    arg_scope: Optional, arg_scope to build the network. If not provided the
      default arg_scope of the network would be used.
@@ -316,8 +394,8 @@ def get_network(network_name, preprocess_images, arg_scope=None):
  if network_name not in networks_map:
    raise ValueError('Unsupported network %s.' % network_name)
  arg_scope = arg_scope or arg_scopes_map[network_name]()

  def _identity_function(inputs, dtype=preprocessed_images_dtype):
    return tf.cast(inputs, dtype=dtype)

  if preprocess_images:
    preprocess_function = _PREPROCESS_FN[network_name]
  else:
@@ -326,5 +404,6 @@ def get_network(network_name, preprocess_images, arg_scope=None):
  @functools.wraps(func)
  def network_fn(inputs, *args, **kwargs):
    with slim.arg_scope(arg_scope):
      return func(preprocess_function(inputs, preprocessed_images_dtype),
                  *args, **kwargs)
  return network_fn
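
A usage sketch of the updated API (argument values are illustrative, not taken from this commit); the preprocessed-image dtype now threads from the caller through preprocessing into the backbone:

network_fn = get_network('mobilenet_v2', preprocess_images=True,
                         preprocessed_images_dtype=tf.float16)
features, end_points = extract_features(
    images,
    output_stride=16,
    model_variant='mobilenet_v2',
    divisible_by=8,
    preprocessed_images_dtype=tf.float16)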
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Cell structure used by NAS."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from deeplab.core.utils import resize_bilinear
from deeplab.core.utils import scale_dimension
arg_scope = tf.contrib.framework.arg_scope
slim = tf.contrib.slim
class NASBaseCell(object):
"""NASNet Cell class that is used as a 'layer' in image architectures.
See https://arxiv.org/abs/1707.07012 and https://arxiv.org/abs/1712.00559.
Args:
num_conv_filters: The number of filters for each convolution operation.
operations: List of operations that are performed in the NASNet Cell in
order.
used_hiddenstates: Binary array that signals if the hiddenstate was used
within the cell. This is used to determine what outputs of the cell
should be concatenated together.
hiddenstate_indices: Determines what hiddenstates should be combined
together with the specified operations to create the NASNet cell.
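    drop_path_keep_prob: Probability to keep each path in the cell when
      training.
    total_num_cells: Total number of cells in the network.
    total_training_steps: Total training steps to help drop path probability
      calculation.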
"""
def __init__(self, num_conv_filters, operations, used_hiddenstates,
hiddenstate_indices, drop_path_keep_prob, total_num_cells,
total_training_steps):
if len(hiddenstate_indices) != len(operations):
raise ValueError(
'Number of hiddenstate_indices and operations should be the same.')
if len(operations) % 2:
raise ValueError('Number of operations should be even.')
self._num_conv_filters = num_conv_filters
self._operations = operations
self._used_hiddenstates = used_hiddenstates
self._hiddenstate_indices = hiddenstate_indices
self._drop_path_keep_prob = drop_path_keep_prob
self._total_num_cells = total_num_cells
self._total_training_steps = total_training_steps
def __call__(self, net, scope, filter_scaling, stride, prev_layer, cell_num):
"""Runs the conv cell."""
self._cell_num = cell_num
self._filter_scaling = filter_scaling
self._filter_size = int(self._num_conv_filters * filter_scaling)
with tf.variable_scope(scope):
net = self._cell_base(net, prev_layer)
for i in range(len(self._operations) // 2):
with tf.variable_scope('comb_iter_{}'.format(i)):
h1 = net[self._hiddenstate_indices[i * 2]]
h2 = net[self._hiddenstate_indices[i * 2 + 1]]
with tf.variable_scope('left'):
h1 = self._apply_conv_operation(
h1, self._operations[i * 2], stride,
self._hiddenstate_indices[i * 2] < 2)
with tf.variable_scope('right'):
h2 = self._apply_conv_operation(
h2, self._operations[i * 2 + 1], stride,
self._hiddenstate_indices[i * 2 + 1] < 2)
with tf.variable_scope('combine'):
h = h1 + h2
net.append(h)
with tf.variable_scope('cell_output'):
net = self._combine_unused_states(net)
return net
def _cell_base(self, net, prev_layer):
"""Runs the beginning of the conv cell before the chosen ops are run."""
filter_size = self._filter_size
if prev_layer is None:
prev_layer = net
else:
if net.shape[2] != prev_layer.shape[2]:
prev_layer = resize_bilinear(
prev_layer, tf.shape(net)[1:3], prev_layer.dtype)
if filter_size != prev_layer.shape[3]:
prev_layer = tf.nn.relu(prev_layer)
prev_layer = slim.conv2d(prev_layer, filter_size, 1, scope='prev_1x1')
prev_layer = slim.batch_norm(prev_layer, scope='prev_bn')
net = tf.nn.relu(net)
net = slim.conv2d(net, filter_size, 1, scope='1x1')
net = slim.batch_norm(net, scope='beginning_bn')
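    # tf.split with num_or_size_splits=1 wraps `net` in a one-element list so
    # hidden states can be indexed uniformly below.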
net = tf.split(axis=3, num_or_size_splits=1, value=net)
net.append(prev_layer)
return net
def _apply_conv_operation(self, net, operation, stride,
is_from_original_input):
"""Applies the predicted conv operation to net."""
if stride > 1 and not is_from_original_input:
stride = 1
input_filters = net.shape[3]
filter_size = self._filter_size
if 'separable' in operation:
num_layers = int(operation.split('_')[-1])
kernel_size = int(operation.split('x')[0][-1])
for layer_num in range(num_layers):
net = tf.nn.relu(net)
net = slim.separable_conv2d(
net,
filter_size,
kernel_size,
depth_multiplier=1,
scope='separable_{0}x{0}_{1}'.format(kernel_size, layer_num + 1),
stride=stride)
net = slim.batch_norm(
net, scope='bn_sep_{0}x{0}_{1}'.format(kernel_size, layer_num + 1))
stride = 1
elif 'atrous' in operation:
kernel_size = int(operation.split('x')[0][-1])
net = tf.nn.relu(net)
if stride == 2:
scaled_height = scale_dimension(tf.shape(net)[1], 0.5)
scaled_width = scale_dimension(tf.shape(net)[2], 0.5)
net = resize_bilinear(net, [scaled_height, scaled_width], net.dtype)
net = slim.conv2d(net, filter_size, kernel_size, rate=1,
scope='atrous_{0}x{0}'.format(kernel_size))
else:
net = slim.conv2d(net, filter_size, kernel_size, rate=2,
scope='atrous_{0}x{0}'.format(kernel_size))
net = slim.batch_norm(net, scope='bn_atr_{0}x{0}'.format(kernel_size))
elif operation in ['none']:
if stride > 1 or (input_filters != filter_size):
net = tf.nn.relu(net)
net = slim.conv2d(net, filter_size, 1, stride=stride, scope='1x1')
net = slim.batch_norm(net, scope='bn_1')
elif 'pool' in operation:
pooling_type = operation.split('_')[0]
pooling_shape = int(operation.split('_')[-1].split('x')[0])
if pooling_type == 'avg':
net = slim.avg_pool2d(net, pooling_shape, stride=stride, padding='SAME')
elif pooling_type == 'max':
net = slim.max_pool2d(net, pooling_shape, stride=stride, padding='SAME')
else:
raise ValueError('Unimplemented pooling type: ', pooling_type)
if input_filters != filter_size:
net = slim.conv2d(net, filter_size, 1, stride=1, scope='1x1')
net = slim.batch_norm(net, scope='bn_1')
else:
raise ValueError('Unimplemented operation', operation)
if operation != 'none':
net = self._apply_drop_path(net)
return net
def _combine_unused_states(self, net):
"""Concatenates the unused hidden states of the cell."""
used_hiddenstates = self._used_hiddenstates
states_to_combine = ([
h for h, is_used in zip(net, used_hiddenstates) if not is_used])
net = tf.concat(values=states_to_combine, axis=3)
return net
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Genotypes used by NAS."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from deeplab.core import nas_cell
class PNASCell(nas_cell.NASBaseCell):
"""Configuration and construction of the PNASNet-5 Cell."""
def __init__(self, num_conv_filters, drop_path_keep_prob, total_num_cells,
total_training_steps):
# Name of operations: op_kernel-size_num-layers.
operations = [
'separable_5x5_2', 'max_pool_3x3', 'separable_7x7_2', 'max_pool_3x3',
'separable_5x5_2', 'separable_3x3_2', 'separable_3x3_2', 'max_pool_3x3',
'separable_3x3_2', 'none'
]
used_hiddenstates = [1, 1, 0, 0, 0, 0, 0]
hiddenstate_indices = [1, 1, 0, 0, 0, 0, 4, 0, 1, 0]
super(PNASCell, self).__init__(
num_conv_filters, operations, used_hiddenstates, hiddenstate_indices,
drop_path_keep_prob, total_num_cells, total_training_steps)
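
A construction sketch (values illustrative): nas_network.pnasnet builds this cell from its hparams roughly as

cell = PNASCell(num_conv_filters=20, drop_path_keep_prob=1.0,
                total_num_cells=12, total_training_steps=500000)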
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Network structure used by NAS."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from deeplab.core import nas_genotypes
from deeplab.core.nas_cell import NASBaseCell
from deeplab.core.utils import resize_bilinear
from deeplab.core.utils import scale_dimension
arg_scope = tf.contrib.framework.arg_scope
slim = tf.contrib.slim
def config(num_conv_filters=20,
total_training_steps=500000,
drop_path_keep_prob=1.0):
return tf.contrib.training.HParams(
# Multiplier when spatial size is reduced by 2.
filter_scaling_rate=2.0,
# Number of filters of the stem output tensor.
num_conv_filters=num_conv_filters,
# Probability to keep each path in the cell when training.
drop_path_keep_prob=drop_path_keep_prob,
# Total training steps to help drop path probability calculation.
total_training_steps=total_training_steps,
)
def nas_arg_scope(weight_decay=4e-5, batch_norm_decay=0.9997,
batch_norm_epsilon=0.001):
"""Default arg scope for the NAS models."""
batch_norm_params = {
# Decay for the moving averages.
'decay': batch_norm_decay,
# epsilon to prevent 0s in variance.
'epsilon': batch_norm_epsilon,
'scale': True,
'fused': True,
}
weights_regularizer = tf.contrib.layers.l2_regularizer(weight_decay)
weights_initializer = tf.contrib.layers.variance_scaling_initializer(
factor=1/3.0, mode='FAN_IN', uniform=True)
with arg_scope([slim.fully_connected, slim.conv2d, slim.separable_conv2d],
weights_regularizer=weights_regularizer,
weights_initializer=weights_initializer):
with arg_scope([slim.fully_connected],
activation_fn=None, scope='FC'):
with arg_scope([slim.conv2d, slim.separable_conv2d],
activation_fn=None, biases_initializer=None):
with arg_scope([slim.batch_norm], **batch_norm_params) as sc:
return sc
def _nas_stem(inputs):
"""Stem used for NAS models."""
net = slim.conv2d(inputs, 64, [3, 3], stride=2,
scope='conv0', padding='SAME')
net = slim.batch_norm(net, scope='conv0_bn')
net = tf.nn.relu(net)
net = slim.conv2d(net, 64, [3, 3], stride=1,
scope='conv1', padding='SAME')
net = slim.batch_norm(net, scope='conv1_bn')
cell_outputs = [net]
net = tf.nn.relu(net)
net = slim.conv2d(net, 128, [3, 3], stride=2,
scope='conv2', padding='SAME')
net = slim.batch_norm(net, scope='conv2_bn')
cell_outputs.append(net)
return net, cell_outputs
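
# Stride bookkeeping for the stem above: conv0 and conv2 each use stride 2,
# so the stem output sits at output_stride = 4, matching the 'Stem' decoder
# end point registered for 'nas_pnasnet' in feature_extractor.py.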
def _build_nas_base(images,
cell,
backbone,
num_classes,
hparams,
global_pool=False,
reuse=None,
scope=None,
final_endpoint=None):
"""Constructs a NAS model.
Args:
images: A tensor of size [batch, height, width, channels].
cell: Cell structure used in the network.
backbone: Backbone structure used in the network. A list of integers in
which value 0 means "output_stride=4", value 1 means "output_stride=8",
value 2 means "output_stride=16", and value 3 means "output_stride=32".
num_classes: Number of classes to predict.
hparams: Hyperparameters needed to construct the network.
global_pool: If True, we perform global average pooling before computing the
logits. Set to True for image classification, False for dense prediction.
reuse: Whether or not the network and its variables should be reused. To be
able to reuse 'scope' must be given.
scope: Optional variable_scope.
final_endpoint: The endpoint to construct the network up to.
Returns:
net: A rank-4 tensor of size [batch, height_out, width_out, channels_out].
end_points: A dictionary from components of the network to the corresponding
activation.
"""
with tf.variable_scope(scope, 'nas', [images], reuse=reuse):
end_points = {}
def add_and_check_endpoint(endpoint_name, net):
end_points[endpoint_name] = net
return final_endpoint and (endpoint_name == final_endpoint)
net, cell_outputs = _nas_stem(images)
if add_and_check_endpoint('Stem', net):
return net, end_points
# Run the cells
filter_scaling = 1.0
for cell_num in range(len(backbone)):
stride = 1
if cell_num == 0:
if backbone[0] == 1:
stride = 2
filter_scaling *= hparams.filter_scaling_rate
else:
if backbone[cell_num] == backbone[cell_num - 1] + 1:
stride = 2
filter_scaling *= hparams.filter_scaling_rate
elif backbone[cell_num] == backbone[cell_num - 1] - 1:
scaled_height = scale_dimension(tf.shape(net)[1], 2)
scaled_width = scale_dimension(tf.shape(net)[2], 2)
net = resize_bilinear(net, [scaled_height, scaled_width], net.dtype)
filter_scaling /= hparams.filter_scaling_rate
net = cell(
net,
scope='cell_{}'.format(cell_num),
filter_scaling=filter_scaling,
stride=stride,
prev_layer=cell_outputs[-2],
cell_num=cell_num)
if add_and_check_endpoint('Cell_{}'.format(cell_num), net):
return net, end_points
cell_outputs.append(net)
net = tf.nn.relu(net)
if global_pool:
# Global average pooling.
net = tf.reduce_mean(net, [1, 2], name='global_pool', keepdims=True)
if num_classes is not None:
net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
normalizer_fn=None, scope='logits')
end_points['predictions'] = slim.softmax(net, scope='predictions')
return net, end_points
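
# Worked example of the backbone encoding (illustrative): with
# backbone = [1, 1, 2, 2], cell 0 strides the stem output from output_stride
# 4 to 8 (value 1) and doubles filter_scaling; the 1 -> 2 transition at cell 2
# strides to output_stride 16 and doubles it again; a 2 -> 1 transition would
# instead bilinearly upsample and halve filter_scaling.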
def pnasnet(images,
num_classes,
is_training=True,
global_pool=False,
output_stride=16,
nas_stem_output_num_conv_filters=20,
nas_training_hyper_parameters=None,
reuse=None,
scope='pnasnet',
final_endpoint=None):
"""Builds PNASNet model."""
hparams = config(num_conv_filters=nas_stem_output_num_conv_filters)
if nas_training_hyper_parameters:
hparams.set_hparam('drop_path_keep_prob',
nas_training_hyper_parameters['drop_path_keep_prob'])
hparams.set_hparam('total_training_steps',
nas_training_hyper_parameters['total_training_steps'])
if not is_training:
tf.logging.info('During inference, setting drop_path_keep_prob = 1.0.')
hparams.set_hparam('drop_path_keep_prob', 1.0)
tf.logging.info(hparams)
if output_stride == 8:
backbone = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
elif output_stride == 16:
backbone = [1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2]
elif output_stride == 32:
backbone = [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]
else:
raise ValueError('Unsupported output_stride ', output_stride)
cell = nas_genotypes.PNASCell(hparams.num_conv_filters,
hparams.drop_path_keep_prob,
len(backbone),
hparams.total_training_steps)
with arg_scope([slim.dropout, slim.batch_norm], is_training=is_training):
return _build_nas_base(
images,
cell=cell,
backbone=backbone,
num_classes=num_classes,
hparams=hparams,
global_pool=global_pool,
reuse=reuse,
scope=scope,
final_endpoint=final_endpoint)
# pylint: disable=unused-argument
def hnasnet(images,
num_classes,
is_training=True,
global_pool=False,
output_stride=16,
nas_stem_output_num_conv_filters=20,
nas_training_hyper_parameters=None,
reuse=None,
scope='hnasnet',
final_endpoint=None):
"""Builds hierarchical model."""
hparams = config(num_conv_filters=nas_stem_output_num_conv_filters)
if nas_training_hyper_parameters:
hparams.set_hparam('drop_path_keep_prob',
nas_training_hyper_parameters['drop_path_keep_prob'])
hparams.set_hparam('total_training_steps',
nas_training_hyper_parameters['total_training_steps'])
if not is_training:
tf.logging.info('During inference, setting drop_path_keep_prob = 1.0.')
hparams.set_hparam('drop_path_keep_prob', 1.0)
tf.logging.info(hparams)
operations = [
'atrous_5x5', 'separable_3x3_2', 'separable_3x3_2', 'atrous_3x3',
'separable_3x3_2', 'separable_3x3_2', 'separable_5x5_2',
'separable_5x5_2', 'separable_5x5_2', 'atrous_5x5'
]
used_hiddenstates = [1, 1, 0, 0, 0, 0, 0]
hiddenstate_indices = [1, 0, 1, 0, 3, 1, 4, 2, 3, 5]
backbone = [0, 0, 0, 1, 2, 1, 2, 2, 3, 3, 2, 1]
cell = NASBaseCell(hparams.num_conv_filters,
operations,
used_hiddenstates,
hiddenstate_indices,
hparams.drop_path_keep_prob,
len(backbone),
hparams.total_training_steps)
with arg_scope([slim.dropout, slim.batch_norm], is_training=is_training):
return _build_nas_base(
images,
cell=cell,
backbone=backbone,
num_classes=num_classes,
hparams=hparams,
global_pool=global_pool,
reuse=reuse,
scope=scope,
final_endpoint=final_endpoint)
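
A minimal end-to-end sketch (input size illustrative) of building the PNASNet backbone above for dense prediction:

images = tf.placeholder(tf.float32, [1, 513, 513, 3])
with slim.arg_scope(nas_arg_scope()):
  features, end_points = pnasnet(images, num_classes=None, is_training=False,
                                 output_stride=16)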
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for resnet_v1_beta module."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from deeplab.core import nas_genotypes
from deeplab.core import nas_network
arg_scope = tf.contrib.framework.arg_scope
slim = tf.contrib.slim
def create_test_input(batch, height, width, channels):
"""Creates test input tensor."""
if None in [batch, height, width, channels]:
return tf.placeholder(tf.float32, (batch, height, width, channels))
else:
return tf.to_float(
np.tile(
np.reshape(
np.reshape(np.arange(height), [height, 1]) +
np.reshape(np.arange(width), [1, width]),
[1, height, width, 1]),
[batch, 1, 1, channels]))
class NASNetworkTest(tf.test.TestCase):
"""Tests with complete small NAS networks."""
def _pnasnet_small(self,
images,
num_classes,
is_training=True,
output_stride=16,
final_endpoint=None):
"""Build PNASNet model backbone."""
hparams = tf.contrib.training.HParams(
filter_scaling_rate=2.0,
num_conv_filters=10,
drop_path_keep_prob=1.0,
total_training_steps=200000,
)
if not is_training:
hparams.set_hparam('drop_path_keep_prob', 1.0)
backbone = [1, 2, 2]
cell = nas_genotypes.PNASCell(hparams.num_conv_filters,
hparams.drop_path_keep_prob,
len(backbone),
hparams.total_training_steps)
with arg_scope([slim.dropout, slim.batch_norm], is_training=is_training):
return nas_network._build_nas_base(
images,
cell=cell,
backbone=backbone,
num_classes=num_classes,
hparams=hparams,
reuse=tf.AUTO_REUSE,
scope='pnasnet_small',
final_endpoint=final_endpoint)
def testFullyConvolutionalEndpointShapes(self):
num_classes = 10
inputs = create_test_input(2, 321, 321, 3)
with slim.arg_scope(nas_network.nas_arg_scope()):
_, end_points = self._pnasnet_small(inputs,
num_classes)
endpoint_to_shape = {
'Stem': [2, 81, 81, 128],
'Cell_0': [2, 41, 41, 100],
'Cell_1': [2, 21, 21, 200],
'Cell_2': [2, 21, 21, 200]}
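    # These shapes follow from the encoding: the stem ends at output_stride 4
    # (321 -> (321 - 1) / 4 + 1 = 81, 128 channels); each 1 -> 2 step halves
    # resolution and doubles filter_scaling; and a PNAS cell concatenates its
    # five unused hidden states (5 * num_conv_filters * filter_scaling
    # channels, e.g. 5 * 10 * 2 = 100 for Cell_0).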
    for endpoint, shape in endpoint_to_shape.items():
self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
if __name__ == '__main__':
tf.test.main()
@@ -19,6 +19,22 @@ import tensorflow as tf
slim = tf.contrib.slim
def resize_bilinear(images, size, output_dtype=tf.float32):
"""Returns resized images as output_type.
Args:
images: A tensor of size [batch, height_in, width_in, channels].
size: A 1-D int32 Tensor of 2 elements: new_height, new_width. The new size
for the images.
output_dtype: The destination type.
Returns:
    A tensor of size [batch, height_out, width_out, channels] with dtype
    output_dtype.
"""
images = tf.image.resize_bilinear(images, size, align_corners=True)
return tf.cast(images, dtype=output_dtype)
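
# E.g. resize_bilinear(features, [65, 65], output_dtype=tf.float16) returns
# float16 features; the cast restores the requested dtype after the resize.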
def scale_dimension(dim, scale):
  """Scales the input dimension.
@@ -36,13 +52,13 @@ def scale_dimension(dim, scale):
def split_separable_conv2d(inputs,
                           filters,
                           kernel_size=3,
                           rate=1,
                           weight_decay=0.00004,
                           depthwise_weights_initializer_stddev=0.33,
                           pointwise_weights_initializer_stddev=0.06,
                           scope=None):
  """Splits a separable conv2d into depthwise and pointwise conv2d.

  This operation differs from `tf.layers.separable_conv2d` as this operation
...@@ -81,4 +97,4 @@ def split_separable_conv2d(inputs, ...@@ -81,4 +97,4 @@ def split_separable_conv2d(inputs,
weights_initializer=tf.truncated_normal_initializer( weights_initializer=tf.truncated_normal_initializer(
stddev=pointwise_weights_initializer_stddev), stddev=pointwise_weights_initializer_stddev),
weights_regularizer=slim.l2_regularizer(weight_decay), weights_regularizer=slim.l2_regularizer(weight_decay),
scope=scope + '_pointwise') scope=scope + '_pointwise')
\ No newline at end of file
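For context, `split_separable_conv2d` above is typically called on an
intermediate feature map. A minimal sketch of such a call (the input
placeholder and scope name are hypothetical, and the module is assumed to be
importable as `deeplab.core.utils`):

```python
import tensorflow as tf

from deeplab.core import utils

# Hypothetical feature map: batch 1, 65x65 spatial size, 256 channels.
features = tf.placeholder(tf.float32, [1, 65, 65, 256])

# Depthwise 3x3 convolution with atrous rate 2, followed by a pointwise 1x1
# projection to 256 output filters (the pointwise op is scoped
# '<scope>_pointwise', as in the code above).
output = utils.split_separable_conv2d(
    features,
    filters=256,
    rate=2,
    weight_decay=0.00004,
    scope='decoder_conv0')
```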
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
...@@ -29,4 +28,4 @@ class UtilsTest(tf.test.TestCase):
if __name__ == '__main__':
  tf.test.main()
\ No newline at end of file
...@@ -52,6 +52,8 @@ slim = tf.contrib.slim
_DEFAULT_MULTI_GRID = [1, 1, 1]
# The cap for tf.clip_by_value.
_CLIP_CAP = 6


class Block(collections.namedtuple('Block', ['scope', 'unit_fn', 'args'])):
...@@ -200,7 +202,9 @@ def xception_module(inputs,
                    activation_fn_in_separable_conv=False,
                    regularize_depthwise=False,
                    outputs_collections=None,
                    scope=None,
                    use_bounded_activation=False,
                    use_explicit_padding=True):
  """An Xception module.

  The output of one Xception module is equal to the sum of `residual` and
...@@ -230,6 +234,11 @@ def xception_module(inputs,
      depthwise convolution weights.
    outputs_collections: Collection to add the Xception unit output.
    scope: Optional variable_scope.
    use_bounded_activation: Whether or not to use bounded activations. Bounded
      activations better lend themselves to quantized inference.
    use_explicit_padding: If True, use explicit padding to make the model fully
      compatible with the open-source version, otherwise use the native
      TensorFlow 'SAME' padding.

  Returns:
    The Xception module's output.
...@@ -250,11 +259,19 @@ def xception_module(inputs,
    def _separable_conv(features, depth, kernel_size, depth_multiplier,
                        regularize_depthwise, rate, stride, scope):
      """Separable conv block."""
      if activation_fn_in_separable_conv:
        activation_fn = tf.nn.relu6 if use_bounded_activation else tf.nn.relu
      else:
        if use_bounded_activation:
          # When use_bounded_activation is True, we clip the feature values and
          # apply relu6 for activation.
          activation_fn = lambda x: tf.clip_by_value(x, -_CLIP_CAP, _CLIP_CAP)
          features = tf.nn.relu6(features)
        else:
          # Original network design.
          activation_fn = None
          features = tf.nn.relu(features)
      return separable_conv2d_same(features,
                                   depth,
                                   kernel_size,
...@@ -262,6 +279,7 @@ def xception_module(inputs,
                                   stride=stride,
                                   rate=rate,
                                   activation_fn=activation_fn,
                                   use_explicit_padding=use_explicit_padding,
                                   regularize_depthwise=regularize_depthwise,
                                   scope=scope)

    for i in range(3):
...@@ -280,9 +298,19 @@ def xception_module(inputs,
                             stride=stride,
                             activation_fn=None,
                             scope='shortcut')
      if use_bounded_activation:
        residual = tf.clip_by_value(residual, -_CLIP_CAP, _CLIP_CAP)
        shortcut = tf.clip_by_value(shortcut, -_CLIP_CAP, _CLIP_CAP)
      outputs = residual + shortcut
      if use_bounded_activation:
        outputs = tf.nn.relu6(outputs)
    elif skip_connection_type == 'sum':
      if use_bounded_activation:
        residual = tf.clip_by_value(residual, -_CLIP_CAP, _CLIP_CAP)
        inputs = tf.clip_by_value(inputs, -_CLIP_CAP, _CLIP_CAP)
      outputs = residual + inputs
      if use_bounded_activation:
        outputs = tf.nn.relu6(outputs)
    elif skip_connection_type == 'none':
      outputs = residual
    else:
...@@ -713,9 +741,9 @@ def xception_arg_scope(weight_decay=0.00004,
                       batch_norm_epsilon=0.001,
                       batch_norm_scale=True,
                       weights_initializer_stddev=0.09,
                       regularize_depthwise=False,
                       use_batch_norm=True,
                       use_bounded_activation=False):
  """Defines the default Xception arg scope.

  Args:
...@@ -728,10 +756,11 @@ def xception_arg_scope(weight_decay=0.00004,
      activations in the batch normalization layer.
    weights_initializer_stddev: The standard deviation of the truncated normal
      weight initializer.
    regularize_depthwise: Whether or not to apply L2-norm regularization on the
      depthwise convolution weights.
    use_batch_norm: Whether or not to use batch normalization.
    use_bounded_activation: Whether or not to use bounded activations. Bounded
      activations better lend themselves to quantized inference.

  Returns:
    An `arg_scope` to use for the Xception models.
...@@ -745,6 +774,7 @@ def xception_arg_scope(weight_decay=0.00004,
    depthwise_regularizer = slim.l2_regularizer(weight_decay)
  else:
    depthwise_regularizer = None
  activation_fn = tf.nn.relu6 if use_bounded_activation else tf.nn.relu
  with slim.arg_scope(
      [slim.conv2d, slim.separable_conv2d],
      weights_initializer=tf.truncated_normal_initializer(
...@@ -757,5 +787,9 @@ def xception_arg_scope(weight_decay=0.00004,
      weights_regularizer=slim.l2_regularizer(weight_decay)):
    with slim.arg_scope(
        [slim.separable_conv2d],
        weights_regularizer=depthwise_regularizer):
      with slim.arg_scope(
          [xception_module],
          use_bounded_activation=use_bounded_activation,
          use_explicit_padding=not use_bounded_activation) as arg_sc:
        return arg_sc
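To make the bounded-activation pattern above concrete, here is a small
standalone sketch (an illustration, not code from this commit) of how clipping
to `[-_CLIP_CAP, _CLIP_CAP]` plus `relu6` keeps every tensor in a fixed range,
which is what makes quantization ranges static instead of data-dependent:

```python
import tensorflow as tf

_CLIP_CAP = 6  # Mirrors the module-level constant above.

def bounded_residual_sum(residual, shortcut):
  """Adds two branches while keeping all values in [-6, 6]."""
  # Clip both inputs so the sum stays within a known, fixed range.
  residual = tf.clip_by_value(residual, -_CLIP_CAP, _CLIP_CAP)
  shortcut = tf.clip_by_value(shortcut, -_CLIP_CAP, _CLIP_CAP)
  # relu6 bounds the output to [0, 6], unlike plain relu.
  return tf.nn.relu6(residual + shortcut)
```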
...@@ -462,6 +462,24 @@ class XceptionNetworkTest(tf.test.TestCase):
        reuse=True)
    self.assertItemsEqual(end_points0.keys(), end_points1.keys())
def testUseBoundedActivation(self):
global_pool = False
num_classes = 10
output_stride = 8
for use_bounded_activation in (True, False):
tf.reset_default_graph()
inputs = create_test_input(2, 321, 321, 3)
with slim.arg_scope(xception.xception_arg_scope(
use_bounded_activation=use_bounded_activation)):
_, _ = self._xception_small(
inputs,
num_classes,
global_pool=global_pool,
output_stride=output_stride,
scope='xception')
for node in tf.get_default_graph().as_graph_def().node:
if node.op.startswith('Relu'):
self.assertEqual(node.op == 'Relu6', use_bounded_activation)
if __name__ == '__main__':
  tf.test.main()
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Wrapper for providing semantic segmentaion data.
The SegmentationDataset class provides both images and annotations (semantic
segmentation and/or instance segmentation) for TensorFlow. Currently, we
support the following datasets:
1. PASCAL VOC 2012 (http://host.robots.ox.ac.uk/pascal/VOC/voc2012/).
PASCAL VOC 2012 semantic segmentation dataset annotates 20 foreground objects
(e.g., bike, person, and so on) and leaves all the other semantic classes as
one background class. The dataset contains 1464, 1449, and 1456 annotated
images for the training, validation, and test sets, respectively.
2. Cityscapes dataset (https://www.cityscapes-dataset.com)
The Cityscapes dataset contains 19 semantic labels (such as road, person, car,
and so on) for urban street scenes.
3. ADE20K dataset (http://groups.csail.mit.edu/vision/datasets/ADE20K)
The ADE20K dataset contains 150 semantic labels covering both urban street
scenes and indoor scenes.
References:
M. Everingham, S. M. A. Eslami, L. V. Gool, C. K. I. Williams, J. Winn,
and A. Zisserman, The pascal visual object classes challenge a retrospective.
IJCV, 2014.
M. Cordts, M. Omran, S. Ramos, T. Rehfeld, M. Enzweiler, R. Benenson,
U. Franke, S. Roth, and B. Schiele, "The cityscapes dataset for semantic urban
scene understanding," In Proc. of CVPR, 2016.
B. Zhou, H. Zhao, X. Puig, S. Fidler, A. Barriuso, A. Torralba, "Scene Parsing
through ADE20K dataset", In Proc. of CVPR, 2017.
"""
import collections
import os
import tensorflow as tf
from deeplab import common
from deeplab import input_preprocess
# Named tuple to describe the dataset properties.
DatasetDescriptor = collections.namedtuple(
'DatasetDescriptor',
[
'splits_to_sizes', # Splits of the dataset into training, val and test.
'num_classes', # Number of semantic classes, including the
# background class (if exists). For example, there
# are 20 foreground classes + 1 background class in
# the PASCAL VOC 2012 dataset. Thus, we set
# num_classes=21.
'ignore_label', # Ignore label value.
])
_CITYSCAPES_INFORMATION = DatasetDescriptor(
splits_to_sizes={
'train': 2975,
'val': 500,
},
num_classes=19,
ignore_label=255,
)
_PASCAL_VOC_SEG_INFORMATION = DatasetDescriptor(
splits_to_sizes={
'train': 1464,
'train_aug': 10582,
'trainval': 2913,
'val': 1449,
},
num_classes=21,
ignore_label=255,
)
_ADE20K_INFORMATION = DatasetDescriptor(
splits_to_sizes={
'train': 20210, # num of samples in images/training
'val': 2000, # num of samples in images/validation
},
num_classes=151,
ignore_label=0,
)
_DATASETS_INFORMATION = {
'cityscapes': _CITYSCAPES_INFORMATION,
'pascal_voc_seg': _PASCAL_VOC_SEG_INFORMATION,
'ade20k': _ADE20K_INFORMATION,
}
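# A hypothetical example (not part of this file) of how one more dataset could
# be supported: define a DatasetDescriptor and register it in the map above,
# e.g. for a two-class foreground/background dataset:
#
#   _MY_DATASET_INFORMATION = DatasetDescriptor(
#       splits_to_sizes={'train': 10000, 'val': 1000},
#       num_classes=2,      # 1 foreground class + 1 background class.
#       ignore_label=255,
#   )
#   _DATASETS_INFORMATION['my_dataset'] = _MY_DATASET_INFORMATION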
# Default file pattern of TFRecord of TensorFlow Example.
_FILE_PATTERN = '%s-*'
def get_cityscapes_dataset_name():
return 'cityscapes'
class Dataset(object):
"""Represents input dataset for deeplab model."""
def __init__(self,
dataset_name,
split_name,
dataset_dir,
batch_size,
crop_size,
min_resize_value=None,
max_resize_value=None,
resize_factor=None,
min_scale_factor=1.,
max_scale_factor=1.,
scale_factor_step_size=0,
model_variant=None,
num_readers=1,
is_training=False,
should_shuffle=False,
should_repeat=False):
"""Initializes the dataset.
Args:
dataset_name: Dataset name.
split_name: A train/val split name.
dataset_dir: The directory of the dataset sources.
batch_size: Batch size.
crop_size: The size used to crop the image and label.
min_resize_value: Desired size of the smaller image side.
max_resize_value: Maximum allowed size of the larger image side.
resize_factor: Resized dimensions are multiple of factor plus one.
min_scale_factor: Minimum scale factor value.
max_scale_factor: Maximum scale factor value.
scale_factor_step_size: The step size from min scale factor to max scale
factor. The input is randomly scaled based on the value of
(min_scale_factor, max_scale_factor, scale_factor_step_size).
model_variant: Model variant (string) for choosing how to mean-subtract
the images. See feature_extractor.network_map for supported model
variants.
num_readers: Number of readers for data provider.
is_training: Boolean, if dataset is for training or not.
should_shuffle: Boolean, if should shuffle the input data.
should_repeat: Boolean, if should repeat the input data.
Raises:
ValueError: Dataset name and split name are not supported.
"""
if dataset_name not in _DATASETS_INFORMATION:
raise ValueError('The specified dataset is not supported yet.')
self.dataset_name = dataset_name
splits_to_sizes = _DATASETS_INFORMATION[dataset_name].splits_to_sizes
if split_name not in splits_to_sizes:
raise ValueError('data split name %s not recognized' % split_name)
if model_variant is None:
tf.logging.warning('Please specify a model_variant. See '
'feature_extractor.network_map for supported model '
'variants.')
self.split_name = split_name
self.dataset_dir = dataset_dir
self.batch_size = batch_size
self.crop_size = crop_size
self.min_resize_value = min_resize_value
self.max_resize_value = max_resize_value
self.resize_factor = resize_factor
self.min_scale_factor = min_scale_factor
self.max_scale_factor = max_scale_factor
self.scale_factor_step_size = scale_factor_step_size
self.model_variant = model_variant
self.num_readers = num_readers
self.is_training = is_training
self.should_shuffle = should_shuffle
self.should_repeat = should_repeat
self.num_of_classes = _DATASETS_INFORMATION[self.dataset_name].num_classes
self.ignore_label = _DATASETS_INFORMATION[self.dataset_name].ignore_label
def _parse_function(self, example_proto):
"""Function to parse the example proto.
Args:
example_proto: Proto in the format of tf.Example.
Returns:
A dictionary with parsed image, label, height, width and image name.
Raises:
ValueError: Label is of wrong shape.
"""
# Currently only supports jpeg and png.
# Need to use this logic because the shape is not known for
# tf.image.decode_image and we rely on this info to
# extend label if necessary.
def _decode_image(content, channels):
return tf.cond(
tf.image.is_jpeg(content),
lambda: tf.image.decode_jpeg(content, channels),
lambda: tf.image.decode_png(content, channels))
features = {
'image/encoded':
tf.FixedLenFeature((), tf.string, default_value=''),
'image/filename':
tf.FixedLenFeature((), tf.string, default_value=''),
'image/format':
tf.FixedLenFeature((), tf.string, default_value='jpeg'),
'image/height':
tf.FixedLenFeature((), tf.int64, default_value=0),
'image/width':
tf.FixedLenFeature((), tf.int64, default_value=0),
'image/segmentation/class/encoded':
tf.FixedLenFeature((), tf.string, default_value=''),
'image/segmentation/class/format':
tf.FixedLenFeature((), tf.string, default_value='png'),
}
parsed_features = tf.parse_single_example(example_proto, features)
image = _decode_image(parsed_features['image/encoded'], channels=3)
label = None
if self.split_name != common.TEST_SET:
label = _decode_image(
parsed_features['image/segmentation/class/encoded'], channels=1)
image_name = parsed_features['image/filename']
if image_name is None:
image_name = tf.constant('')
sample = {
common.IMAGE: image,
common.IMAGE_NAME: image_name,
common.HEIGHT: parsed_features['image/height'],
common.WIDTH: parsed_features['image/width'],
}
if label is not None:
if label.get_shape().ndims == 2:
label = tf.expand_dims(label, 2)
elif label.get_shape().ndims == 3 and label.shape.dims[2] == 1:
pass
else:
raise ValueError('Input label shape must be [height, width], or '
'[height, width, 1].')
label.set_shape([None, None, 1])
sample[common.LABELS_CLASS] = label
return sample
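# For reference, a sketch (with hypothetical values) of a tf.Example that
# matches the feature spec parsed above:
#
#   def _bytes_feature(value):
#     return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
#
#   def _int64_feature(value):
#     return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
#
#   example = tf.train.Example(features=tf.train.Features(feature={
#       'image/encoded': _bytes_feature(encoded_jpeg_bytes),
#       'image/filename': _bytes_feature(b'2007_000033'),
#       'image/format': _bytes_feature(b'jpeg'),
#       'image/height': _int64_feature(366),
#       'image/width': _int64_feature(500),
#       'image/segmentation/class/encoded': _bytes_feature(encoded_png_bytes),
#       'image/segmentation/class/format': _bytes_feature(b'png'),
#   }))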
def _preprocess_image(self, sample):
"""Preprocesses the image and label.
Args:
sample: A sample containing image and label.
Returns:
sample: Sample with preprocessed image and label.
Raises:
ValueError: Ground truth label not provided during training.
"""
image = sample[common.IMAGE]
label = sample[common.LABELS_CLASS]
original_image, image, label = input_preprocess.preprocess_image_and_label(
image=image,
label=label,
crop_height=self.crop_size[0],
crop_width=self.crop_size[1],
min_resize_value=self.min_resize_value,
max_resize_value=self.max_resize_value,
resize_factor=self.resize_factor,
min_scale_factor=self.min_scale_factor,
max_scale_factor=self.max_scale_factor,
scale_factor_step_size=self.scale_factor_step_size,
ignore_label=self.ignore_label,
is_training=self.is_training,
model_variant=self.model_variant)
sample[common.IMAGE] = image
if not self.is_training:
# Original image is only used during visualization.
sample[common.ORIGINAL_IMAGE] = original_image
if label is not None:
sample[common.LABEL] = label
# Remove common.LABELS_CLASS key from the sample since it is only used to
# derive the label and is not used in training and evaluation.
sample.pop(common.LABELS_CLASS, None)
return sample
def get_one_shot_iterator(self):
"""Gets an iterator that iterates across the dataset once.
Returns:
An iterator of type tf.data.Iterator.
"""
files = self._get_all_files()
dataset = (
tf.data.TFRecordDataset(files, num_parallel_reads=self.num_readers)
.map(self._parse_function, num_parallel_calls=self.num_readers)
.map(self._preprocess_image, num_parallel_calls=self.num_readers))
if self.should_shuffle:
dataset = dataset.shuffle(buffer_size=100)
if self.should_repeat:
dataset = dataset.repeat() # Repeat forever for training.
else:
dataset = dataset.repeat(1)
dataset = dataset.batch(self.batch_size).prefetch(self.batch_size)
return dataset.make_one_shot_iterator()
def _get_all_files(self):
"""Gets all the files to read data from.
Returns:
A list of input files.
"""
file_pattern = _FILE_PATTERN
file_pattern = os.path.join(self.dataset_dir,
file_pattern % self.split_name)
return tf.gfile.Glob(file_pattern)
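A minimal sketch of wiring this class into a training input pipeline (the
dataset directory and hyperparameter values below are placeholders, not
defaults from this file):

```python
import tensorflow as tf

from deeplab import common
from deeplab.datasets import data_generator

dataset = data_generator.Dataset(
    dataset_name='pascal_voc_seg',
    split_name='train',
    dataset_dir='/tmp/pascal_voc_seg/tfrecord',  # Placeholder path.
    batch_size=8,
    crop_size=[513, 513],
    min_scale_factor=0.5,
    max_scale_factor=2.0,
    scale_factor_step_size=0.25,
    model_variant='xception_65',
    num_readers=2,
    is_training=True,
    should_shuffle=True,
    should_repeat=True)

# One-shot iterator yielding dictionaries keyed by the constants in common.py.
samples = dataset.get_one_shot_iterator().get_next()
images, labels = samples[common.IMAGE], samples[common.LABEL]
```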
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for deeplab.datasets.data_generator."""
from __future__ import print_function
import collections
import tensorflow as tf
from deeplab import common
from deeplab.datasets import data_generator
ImageAttributes = collections.namedtuple(
'ImageAttributes', ['image', 'label', 'height', 'width', 'image_name'])
class DatasetTest(tf.test.TestCase):
# Note: the training dataset cannot be tested directly since it involves a
# shuffle operation. With shuffling disabled, the training dataset behaves
# the same as the validation dataset, so it is not tested separately.
def testPascalVocSegTestData(self):
dataset = data_generator.Dataset(
dataset_name='pascal_voc_seg',
split_name='val',
dataset_dir=
'research/deeplab/testing/pascal_voc_seg',
batch_size=1,
crop_size=[3, 3], # Use small size for testing.
min_resize_value=3,
max_resize_value=3,
resize_factor=None,
min_scale_factor=0.01,
max_scale_factor=2.0,
scale_factor_step_size=0.25,
is_training=False,
model_variant='mobilenet_v2')
self.assertAllEqual(dataset.num_of_classes, 21)
self.assertAllEqual(dataset.ignore_label, 255)
num_of_images = 3
with self.test_session() as sess:
iterator = dataset.get_one_shot_iterator()
for i in range(num_of_images):
batch = iterator.get_next()
batch, = sess.run([batch])
image_attributes = _get_attributes_of_image(i)
self.assertAllEqual(batch[common.IMAGE][0], image_attributes.image)
self.assertAllEqual(batch[common.LABEL][0], image_attributes.label)
self.assertEqual(batch[common.HEIGHT][0], image_attributes.height)
self.assertEqual(batch[common.WIDTH][0], image_attributes.width)
self.assertEqual(batch[common.IMAGE_NAME][0],
image_attributes.image_name)
# All data have been read.
with self.assertRaisesRegexp(tf.errors.OutOfRangeError, ''):
sess.run([iterator.get_next()])
def _get_attributes_of_image(index):
"""Gets the attributes of the image.
Args:
index: Index of image in all images.
Returns:
Attributes of the image in the format of ImageAttributes.
Raises:
ValueError: If index is of wrong value.
"""
if index == 0:
return ImageAttributes(
image=IMAGE_1,
label=LABEL_1,
height=366,
width=500,
image_name='2007_000033')
elif index == 1:
return ImageAttributes(
image=IMAGE_2,
label=LABEL_2,
height=335,
width=500,
image_name='2007_000042')
elif index == 2:
return ImageAttributes(
image=IMAGE_3,
label=LABEL_3,
height=333,
width=500,
image_name='2007_000061')
else:
raise ValueError('Index can only be 0, 1 or 2.')
IMAGE_1 = (
(
(57., 41., 18.),
(151.5, 138., 111.5),
(107., 158., 143.),
),
(
(104.5, 141., 191.),
(101.75, 72.5, 120.75),
(86.5, 139.5, 120.),
),
(
(96., 85., 145.),
(123.5, 107.5, 97.),
(61., 148., 116.),
),
)
LABEL_1 = (
(
(70,),
(227,),
(251,),
),
(
(101,),
(0,),
(10,),
),
(
(145,),
(245,),
(146,),
),
)
IMAGE_2 = (
(
(94., 64., 98.),
(145.5, 136.5, 134.5),
(108., 162., 172.),
),
(
(168., 157., 213.),
(161.5, 154.5, 148.),
(25., 46., 93.),
),
(
(255., 204., 237.),
(124., 102., 126.5),
(155., 181., 82.),
),
)
LABEL_2 = (
(
(44,),
(146,),
(121,),
),
(
(108,),
(118,),
(6,),
),
(
(246,),
(121,),
(108,),
),
)
IMAGE_3 = (
(
(235., 173., 150.),
(145.5, 83.5, 102.),
(82., 149., 158.),
),
(
(130., 95., 14.),
(132.5, 141.5, 93.),
(119., 85., 86.),
),
(
(127.5, 127.5, 127.5),
(127.5, 127.5, 127.5),
(127.5, 127.5, 127.5),
),
)
LABEL_3 = (
(
(91,),
(120,),
(132,),
),
(
(135,),
(139,),
(72,),
),
(
(255,),
(255,),
(255,),
),
)
if __name__ == '__main__':
tf.test.main()
...@@ -17,18 +17,12 @@
See model.py for more details and usage.
"""

import tensorflow as tf
from deeplab import common
from deeplab import model
from deeplab.datasets import data_generator

flags = tf.app.flags

FLAGS = flags.FLAGS

flags.DEFINE_string('master', '', 'BNS name of the tensorflow server')
...@@ -84,31 +78,40 @@ flags.DEFINE_integer('max_number_of_evaluations', 0,
def main(unused_argv):
  tf.logging.set_verbosity(tf.logging.INFO)

  dataset = data_generator.Dataset(
      dataset_name=FLAGS.dataset,
      split_name=FLAGS.eval_split,
      dataset_dir=FLAGS.dataset_dir,
      batch_size=FLAGS.eval_batch_size,
      crop_size=FLAGS.eval_crop_size,
      min_resize_value=FLAGS.min_resize_value,
      max_resize_value=FLAGS.max_resize_value,
      resize_factor=FLAGS.resize_factor,
      model_variant=FLAGS.model_variant,
      num_readers=2,
      is_training=False,
      should_shuffle=False,
      should_repeat=False)

  tf.gfile.MakeDirs(FLAGS.eval_logdir)
  tf.logging.info('Evaluating on %s set', FLAGS.eval_split)

  with tf.Graph().as_default():
    samples = dataset.get_one_shot_iterator().get_next()

    model_options = common.ModelOptions(
        outputs_to_num_classes={common.OUTPUT_TYPE: dataset.num_of_classes},
        crop_size=FLAGS.eval_crop_size,
        atrous_rates=FLAGS.atrous_rates,
        output_stride=FLAGS.output_stride)

    # Set shape in order for tf.contrib.tfprof.model_analyzer to work properly.
    samples[common.IMAGE].set_shape(
        [FLAGS.eval_batch_size,
         FLAGS.eval_crop_size[0],
         FLAGS.eval_crop_size[1],
         3])

    if tuple(FLAGS.eval_scales) == (1.0,):
      tf.logging.info('Performing single-scale test.')
      predictions = model.predict_labels(samples[common.IMAGE], model_options,
...@@ -138,34 +141,32 @@ def main(unused_argv):
      predictions_tag += '_flipped'

    # Define the evaluation metric.
    miou, update_op = tf.metrics.mean_iou(
        predictions, labels, dataset.num_of_classes, weights=weights)
    tf.summary.scalar(predictions_tag, miou)

    summary_op = tf.summary.merge_all()
    summary_hook = tf.contrib.training.SummaryAtEndHook(
        log_dir=FLAGS.eval_logdir, summary_op=summary_op)
    hooks = [summary_hook]

    num_eval_iters = None
    if FLAGS.max_number_of_evaluations > 0:
      num_eval_iters = FLAGS.max_number_of_evaluations

    tf.contrib.tfprof.model_analyzer.print_model_analysis(
        tf.get_default_graph(),
        tfprof_options=tf.contrib.tfprof.model_analyzer.
        TRAINABLE_VARS_PARAMS_STAT_OPTIONS)
    tf.contrib.tfprof.model_analyzer.print_model_analysis(
        tf.get_default_graph(),
        tfprof_options=tf.contrib.tfprof.model_analyzer.FLOAT_OPS_OPTIONS)
    tf.contrib.training.evaluate_repeatedly(
        master=FLAGS.master,
        checkpoint_dir=FLAGS.checkpoint_dir,
        eval_ops=[update_op],
        max_number_of_evaluations=num_eval_iters,
        hooks=hooks,
        eval_interval_secs=FLAGS.eval_interval_secs)
...
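For reference on the metric contract used above: `tf.metrics.mean_iou` returns
a `(value, update_op)` pair whose running value is only meaningful after
`update_op` has been run over the evaluation batches; `evaluate_repeatedly`
does exactly that via `eval_ops=[update_op]`. A self-contained sketch of the
same contract (toy values, not from this commit):

```python
import tensorflow as tf

labels = tf.constant([0, 1, 0, 0])
predictions = tf.constant([0, 1, 1, 0])
miou, update_op = tf.metrics.mean_iou(labels, predictions, num_classes=2)

with tf.Session() as sess:
  # The metric's confusion-matrix accumulator lives in local variables.
  sess.run(tf.local_variables_initializer())
  sess.run(update_op)    # Accumulate statistics for this "batch".
  print(sess.run(miou))  # Mean IoU over both classes.
```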
...@@ -53,11 +53,15 @@ flags.DEFINE_multi_float('inference_scales', [1.0],
flags.DEFINE_bool('add_flipped_images', False,
                  'Add flipped images during inference or not.')

flags.DEFINE_bool('save_inference_graph', False,
                  'Save inference graph in text proto.')

# Input name of the exported model.
_INPUT_NAME = 'ImageTensor'

# Output name of the exported model.
_OUTPUT_NAME = 'SemanticPredictions'

_RAW_OUTPUT_NAME = 'RawSemanticPredictions'


def _create_input_tensors():
...@@ -128,8 +132,10 @@ def main(unused_argv):
    predictions = tf.cast(predictions[common.OUTPUT_TYPE], tf.float32)
    # Crop the valid regions from the predictions.
    raw_predictions = tf.identity(predictions, _RAW_OUTPUT_NAME)
    semantic_predictions = tf.slice(
        raw_predictions,
        [0, 0, 0],
        [1, resized_image_size[0], resized_image_size[1]])
    # Resize back the prediction to the original image size.
...@@ -147,9 +153,11 @@ def main(unused_argv):
    saver = tf.train.Saver(tf.model_variables())

    dirname = os.path.dirname(FLAGS.export_path)
    tf.gfile.MakeDirs(dirname)
    graph_def = tf.get_default_graph().as_graph_def(add_shapes=True)
    freeze_graph.freeze_graph_with_def_protos(
        graph_def,
        saver.as_saver_def(),
        FLAGS.checkpoint_path,
        _OUTPUT_NAME,
...@@ -159,6 +167,9 @@ def main(unused_argv):
        clear_devices=True,
        initializer_nodes=None)

    if FLAGS.save_inference_graph:
      tf.train.write_graph(graph_def, dirname, 'inference_graph.pbtxt')

if __name__ == '__main__':
  flags.mark_flag_as_required('checkpoint_path')
...
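Once exported, the frozen graph can be driven purely by the tensor names above
(`ImageTensor` in, `SemanticPredictions` out). A loading sketch, where the
`.pb` path and the dummy input are placeholders:

```python
import numpy as np
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
  graph_def = tf.GraphDef()
  # Placeholder path; use the value that was passed as --export_path.
  with tf.gfile.GFile('/tmp/deeplab/frozen_inference_graph.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())
  tf.import_graph_def(graph_def, name='')

with tf.Session(graph=graph) as sess:
  # Dummy uint8 RGB image of batch size 1; shape [1, height, width, 3].
  image = np.zeros((1, 513, 513, 3), dtype=np.uint8)
  semantic = sess.run('SemanticPredictions:0',
                      feed_dict={'ImageTensor:0': image})
```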
...@@ -79,12 +79,14 @@ ${PATH_TO_DATASET} is the directory in which the Cityscapes dataset resides.
3. Users can omit the flag `decoder_output_stride` if they do not want to use
the decoder structure.

4. Change and add the following flags in order to use the provided dense
prediction cell. Note that `decoder_output_stride` must be set if you want to
use the provided checkpoints, which include the decoder module.

```bash
--model_variant="xception_71"
--dense_prediction_cell_json="deeplab/core/dense_prediction_cell_branch5_top1_cityscapes.json"
--decoder_output_stride=4
```

A local evaluation job using `xception_65` can be run with the following
...