Commit 146a37c6 authored by Yukun Zhu's avatar Yukun Zhu Committed by aquariusjay
Browse files

Update new models and internal changes (#7788)

* Update for py3 and some internal changes

* Internal refactor

PiperOrigin-RevId: 279809827

* clean up
parent 1498d941
......@@ -157,6 +157,17 @@ under tensorflow/models. Please refer to the LICENSE for details.
## Change Logs
### March 27, 2019
* Supported using different loss weights on different classes during training.
**Contributor**: Yuwei Yang.
### March 26, 2019
* Supported ResNet-v1-18. **Contributor**: Michalis Raptis.
### March 6, 2019
* Released the evaluation code (under the `evaluation` folder) for image
......
......@@ -34,6 +34,9 @@ flags.DEFINE_integer('max_resize_value', None,
flags.DEFINE_integer('resize_factor', None,
'Resized dimensions are multiple of factor plus one.')
flags.DEFINE_boolean('keep_aspect_ratio', True,
'Keep aspect ratio after resizing or not.')
# Model dependent flags.
flags.DEFINE_integer('logits_kernel_size', 1,
......@@ -99,11 +102,8 @@ flags.DEFINE_enum('merge_method', 'max', ['max', 'avg'],
flags.DEFINE_boolean(
'prediction_with_upsampled_logits', True,
'When performing prediction, there are two options: (1) bilinear '
'upsampling the logits followed by argmax, or (2) armax followed by '
'nearest upsampling the predicted labels. The second option may introduce '
'some "blocking effect", but it is more computationally efficient. '
'Currently, prediction_with_upsampled_logits=False is only supported for '
'single-scale inference.')
'upsampling the logits followed by softmax, or (2) softmax followed by '
'bilinear upsampling.')
flags.DEFINE_string(
'dense_prediction_cell_json',
......@@ -114,10 +114,43 @@ flags.DEFINE_integer(
'nas_stem_output_num_conv_filters', 20,
'Number of filters of the stem output tensor in NAS models.')
flags.DEFINE_bool('nas_use_classification_head', False,
'Use image classification head for NAS model variants.')
flags.DEFINE_bool('nas_remove_os32_stride', False,
'Remove the stride in the output stride 32 branch.')
flags.DEFINE_bool('use_bounded_activation', False,
'Whether or not to use bounded activations. Bounded '
'activations better lend themselves to quantized inference.')
flags.DEFINE_boolean('aspp_with_concat_projection', True,
'ASPP with concat projection.')
flags.DEFINE_boolean('aspp_with_squeeze_and_excitation', False,
'ASPP with squeeze and excitation.')
flags.DEFINE_integer('aspp_convs_filters', 256, 'ASPP convolution filters.')
flags.DEFINE_boolean('decoder_use_sum_merge', False,
'Decoder uses simply sum merge.')
flags.DEFINE_integer('decoder_filters', 256, 'Decoder filters.')
flags.DEFINE_boolean('decoder_output_is_logits', False,
'Use decoder output as logits or not.')
flags.DEFINE_boolean('image_se_uses_qsigmoid', False, 'Use q-sigmoid.')
flags.DEFINE_multi_float(
'label_weights', None,
'A list of label weights, each element represents the weight for the label '
'of its index, for example, label_weights = [0.1, 0.5] means the weight '
'for label 0 is 0.1 and the weight for label 1 is 0.5. If set as None, all '
'the labels have the same weight 1.0.')
flags.DEFINE_float('batch_norm_decay', 0.9997, 'Batchnorm decay.')
FLAGS = flags.FLAGS
# Constants
......@@ -160,8 +193,18 @@ class ModelOptions(
'divisible_by',
'prediction_with_upsampled_logits',
'dense_prediction_cell_config',
'nas_stem_output_num_conv_filters',
'use_bounded_activation'
'nas_architecture_options',
'use_bounded_activation',
'aspp_with_concat_projection',
'aspp_with_squeeze_and_excitation',
'aspp_convs_filters',
'decoder_use_sum_merge',
'decoder_filters',
'decoder_output_is_logits',
'image_se_uses_qsigmoid',
'label_weights',
'sync_batch_norm_method',
'batch_norm_decay',
])):
"""Immutable class to hold model options."""
......@@ -204,18 +247,45 @@ class ModelOptions(
image_pooling_stride = [1, 1]
if FLAGS.image_pooling_stride:
image_pooling_stride = [int(x) for x in FLAGS.image_pooling_stride]
label_weights = FLAGS.label_weights
if label_weights is None:
label_weights = 1.0
nas_architecture_options = {
'nas_stem_output_num_conv_filters': (
FLAGS.nas_stem_output_num_conv_filters),
'nas_use_classification_head': FLAGS.nas_use_classification_head,
'nas_remove_os32_stride': FLAGS.nas_remove_os32_stride,
}
return super(ModelOptions, cls).__new__(
cls, outputs_to_num_classes, crop_size, atrous_rates, output_stride,
preprocessed_images_dtype, FLAGS.merge_method,
preprocessed_images_dtype,
FLAGS.merge_method,
FLAGS.add_image_level_feature,
image_pooling_crop_size,
image_pooling_stride,
FLAGS.aspp_with_batch_norm,
FLAGS.aspp_with_separable_conv, FLAGS.multi_grid, decoder_output_stride,
FLAGS.decoder_use_separable_conv, FLAGS.logits_kernel_size,
FLAGS.model_variant, FLAGS.depth_multiplier, FLAGS.divisible_by,
FLAGS.prediction_with_upsampled_logits, dense_prediction_cell_config,
FLAGS.nas_stem_output_num_conv_filters, FLAGS.use_bounded_activation)
FLAGS.aspp_with_separable_conv,
FLAGS.multi_grid,
decoder_output_stride,
FLAGS.decoder_use_separable_conv,
FLAGS.logits_kernel_size,
FLAGS.model_variant,
FLAGS.depth_multiplier,
FLAGS.divisible_by,
FLAGS.prediction_with_upsampled_logits,
dense_prediction_cell_config,
nas_architecture_options,
FLAGS.use_bounded_activation,
FLAGS.aspp_with_concat_projection,
FLAGS.aspp_with_squeeze_and_excitation,
FLAGS.aspp_convs_filters,
FLAGS.decoder_use_sum_merge,
FLAGS.decoder_filters,
FLAGS.decoder_output_is_logits,
FLAGS.image_se_uses_qsigmoid,
label_weights,
'None',
FLAGS.batch_norm_decay)
def __deepcopy__(self, memo):
return ModelOptions(copy.deepcopy(self.outputs_to_num_classes),
......
# Lint as: python2, python3
# Copyright 2019 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Augment slim.conv2d with optional Weight Standardization (WS).
WS is a normalization method to accelerate micro-batch training. When used with
Group Normalization and trained with 1 image/GPU, WS is able to match or
outperform the performances of BN trained with large batch sizes.
[1] Siyuan Qiao, Huiyu Wang, Chenxi Liu, Wei Shen, Alan Yuille
Weight Standardization. arXiv:1903.10520
[2] Lei Huang, Xianglong Liu, Yang Liu, Bo Lang, Dacheng Tao
Centered Weight Normalization in Accelerating Training of Deep Neural
Networks. ICCV 2017
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.contrib import framework as contrib_framework
from tensorflow.contrib import layers as contrib_layers
from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.layers.python.layers import utils
class Conv2D(tf.keras.layers.Conv2D, tf.layers.Layer):
  """2D convolution layer (e.g. spatial convolution over images).

  This layer creates a convolution kernel that is convolved
  (actually cross-correlated) with the layer input to produce a tensor of
  outputs. If `use_bias` is True (and a `bias_initializer` is provided),
  a bias vector is created and added to the outputs. Finally, if
  `activation` is not `None`, it is applied to the outputs as well.

  Unlike the Keras base class, this layer optionally applies Weight
  Standardization (Qiao et al., arXiv:1903.10520): the kernel is normalized
  to zero mean and unit variance across its spatial and input-channel axes
  before each convolution.
  """

  def __init__(self,
               filters,
               kernel_size,
               strides=(1, 1),
               padding='valid',
               data_format='channels_last',
               dilation_rate=(1, 1),
               activation=None,
               use_bias=True,
               kernel_initializer=None,
               bias_initializer=tf.zeros_initializer(),
               kernel_regularizer=None,
               bias_regularizer=None,
               use_weight_standardization=False,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               trainable=True,
               name=None,
               **kwargs):
    """Constructs the 2D convolution layer.

    Args:
      filters: Integer, the dimensionality of the output space (i.e. the number
        of filters in the convolution).
      kernel_size: An integer or tuple/list of 2 integers, specifying the height
        and width of the 2D convolution window. Can be a single integer to
        specify the same value for all spatial dimensions.
      strides: An integer or tuple/list of 2 integers, specifying the strides of
        the convolution along the height and width. Can be a single integer to
        specify the same value for all spatial dimensions. Specifying any stride
        value != 1 is incompatible with specifying any `dilation_rate` value !=
        1.
      padding: One of `"valid"` or `"same"` (case-insensitive).
      data_format: A string, one of `channels_last` (default) or
        `channels_first`. The ordering of the dimensions in the inputs.
        `channels_last` corresponds to inputs with shape `(batch, height, width,
        channels)` while `channels_first` corresponds to inputs with shape
        `(batch, channels, height, width)`.
      dilation_rate: An integer or tuple/list of 2 integers, specifying the
        dilation rate to use for dilated convolution. Can be a single integer to
        specify the same value for all spatial dimensions. Currently, specifying
        any `dilation_rate` value != 1 is incompatible with specifying any
        stride value != 1.
      activation: Activation function. Set it to None to maintain a linear
        activation.
      use_bias: Boolean, whether the layer uses a bias.
      kernel_initializer: An initializer for the convolution kernel.
      bias_initializer: An initializer for the bias vector. If None, the default
        initializer will be used.
      kernel_regularizer: Optional regularizer for the convolution kernel.
      bias_regularizer: Optional regularizer for the bias vector.
      use_weight_standardization: Boolean, whether the layer uses weight
        standardization.
      activity_regularizer: Optional regularizer function for the output.
      kernel_constraint: Optional projection function to be applied to the
        kernel after being updated by an `Optimizer` (e.g. used to implement
        norm constraints or value constraints for layer weights). The function
        must take as input the unprojected variable and must return the
        projected variable (which must have the same shape). Constraints are not
        safe to use when doing asynchronous distributed training.
      bias_constraint: Optional projection function to be applied to the bias
        after being updated by an `Optimizer`.
      trainable: Boolean, if `True` also add variables to the graph collection
        `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
      name: A string, the name of the layer.
      **kwargs: Arbitrary keyword arguments passed to tf.keras.layers.Conv2D
    """
    super(Conv2D, self).__init__(
        filters=filters,
        kernel_size=kernel_size,
        strides=strides,
        padding=padding,
        data_format=data_format,
        dilation_rate=dilation_rate,
        activation=activation,
        use_bias=use_bias,
        kernel_initializer=kernel_initializer,
        bias_initializer=bias_initializer,
        kernel_regularizer=kernel_regularizer,
        bias_regularizer=bias_regularizer,
        activity_regularizer=activity_regularizer,
        kernel_constraint=kernel_constraint,
        bias_constraint=bias_constraint,
        trainable=trainable,
        name=name,
        **kwargs)
    # Extra switch not present in the Keras base class; consumed in call().
    self.use_weight_standardization = use_weight_standardization

  def call(self, inputs):
    """Applies the (optionally weight-standardized) convolution."""
    if self.use_weight_standardization:
      # Standardize the kernel over axes [0, 1, 2] (height, width, input
      # channels), i.e. independently per output channel. The 1e-5 epsilon
      # guards against division by zero for near-constant kernels.
      mean, var = tf.nn.moments(self.kernel, [0, 1, 2], keep_dims=True)
      kernel = (self.kernel - mean) / tf.sqrt(var + 1e-5)
      # NOTE(review): relies on the private Keras attribute `_convolution_op`
      # created by the base class during build().
      outputs = self._convolution_op(inputs, kernel)
    else:
      outputs = self._convolution_op(inputs, self.kernel)

    if self.use_bias:
      if self.data_format == 'channels_first':
        if self.rank == 1:
          # tf.nn.bias_add does not accept a 1D input tensor.
          bias = tf.reshape(self.bias, (1, self.filters, 1))
          outputs += bias
        else:
          outputs = tf.nn.bias_add(outputs, self.bias, data_format='NCHW')
      else:
        outputs = tf.nn.bias_add(outputs, self.bias, data_format='NHWC')

    if self.activation is not None:
      return self.activation(outputs)
    return outputs
@contrib_framework.add_arg_scope
def conv2d(inputs,
           num_outputs,
           kernel_size,
           stride=1,
           padding='SAME',
           data_format=None,
           rate=1,
           activation_fn=tf.nn.relu,
           normalizer_fn=None,
           normalizer_params=None,
           weights_initializer=contrib_layers.xavier_initializer(),
           weights_regularizer=None,
           biases_initializer=tf.zeros_initializer(),
           biases_regularizer=None,
           use_weight_standardization=False,
           reuse=None,
           variables_collections=None,
           outputs_collections=None,
           trainable=True,
           scope=None):
  """Adds a 2D convolution followed by an optional batch_norm layer.

  `convolution` creates a variable called `weights`, representing the
  convolutional kernel, that is convolved (actually cross-correlated) with the
  `inputs` to produce a `Tensor` of activations. If a `normalizer_fn` is
  provided (such as `batch_norm`), it is then applied. Otherwise, if
  `normalizer_fn` is None and a `biases_initializer` is provided then a `biases`
  variable would be created and added the activations. Finally, if
  `activation_fn` is not `None`, it is applied to the activations as well.

  Performs atrous convolution with input stride/dilation rate equal to `rate`
  if a value > 1 for any dimension of `rate` is specified. In this case
  `stride` values != 1 are not supported.

  Args:
    inputs: A Tensor of rank N+2 of shape `[batch_size] + input_spatial_shape +
      [in_channels]` if data_format does not start with "NC" (default), or
      `[batch_size, in_channels] + input_spatial_shape` if data_format starts
      with "NC".
    num_outputs: Integer, the number of output filters.
    kernel_size: A sequence of N positive integers specifying the spatial
      dimensions of the filters. Can be a single integer to specify the same
      value for all spatial dimensions.
    stride: A sequence of N positive integers specifying the stride at which to
      compute output. Can be a single integer to specify the same value for all
      spatial dimensions. Specifying any `stride` value != 1 is incompatible
      with specifying any `rate` value != 1.
    padding: One of `"VALID"` or `"SAME"`.
    data_format: A string or None. Specifies whether the channel dimension of
      the `input` and output is the last dimension (default, or if `data_format`
      does not start with "NC"), or the second dimension (if `data_format`
      starts with "NC"). For N=1, the valid values are "NWC" (default) and
      "NCW". For N=2, the valid values are "NHWC" (default) and "NCHW". For
      N=3, the valid values are "NDHWC" (default) and "NCDHW".
    rate: A sequence of N positive integers specifying the dilation rate to use
      for atrous convolution. Can be a single integer to specify the same value
      for all spatial dimensions. Specifying any `rate` value != 1 is
      incompatible with specifying any `stride` value != 1.
    activation_fn: Activation function. The default value is a ReLU function.
      Explicitly set it to None to skip it and maintain a linear activation.
    normalizer_fn: Normalization function to use instead of `biases`. If
      `normalizer_fn` is provided then `biases_initializer` and
      `biases_regularizer` are ignored and `biases` are not created nor added.
      default set to None for no normalizer function
    normalizer_params: Normalization function parameters.
    weights_initializer: An initializer for the weights.
    weights_regularizer: Optional regularizer for the weights.
    biases_initializer: An initializer for the biases. If None skip biases.
    biases_regularizer: Optional regularizer for the biases.
    use_weight_standardization: Boolean, whether the layer uses weight
      standardization.
    reuse: Whether or not the layer and its variables should be reused. To be
      able to reuse the layer scope must be given.
    variables_collections: Optional list of collections for all the variables or
      a dictionary containing a different list of collection per variable.
    outputs_collections: Collection to add the outputs.
    trainable: If `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
    scope: Optional scope for `variable_scope`.

  Returns:
    A tensor representing the output of the operation.

  Raises:
    ValueError: If `data_format` is invalid.
    ValueError: Both 'rate' and `stride` are not uniformly 1.
  """
  if data_format not in [None, 'NWC', 'NCW', 'NHWC', 'NCHW', 'NDHWC', 'NCDHW']:
    raise ValueError('Invalid data_format: %r' % (data_format,))

  # Expose the Keras layer's `kernel`/`bias` variables under the slim naming
  # convention (`weights`/`biases`) via a custom variable getter.
  # pylint: disable=protected-access
  layer_variable_getter = layers._build_variable_getter({
      'bias': 'biases',
      'kernel': 'weights'
  })
  # pylint: enable=protected-access

  with tf.variable_scope(
      scope, 'Conv', [inputs], reuse=reuse,
      custom_getter=layer_variable_getter) as sc:
    inputs = tf.convert_to_tensor(inputs)
    input_rank = inputs.get_shape().ndims

    # Only 2D (NHWC / NCHW) inputs are supported by this wrapper.
    if input_rank != 4:
      raise ValueError('Convolution expects input with rank %d, got %d' %
                       (4, input_rank))

    data_format = ('channels_first' if data_format and
                   data_format.startswith('NC') else 'channels_last')
    layer = Conv2D(
        filters=num_outputs,
        kernel_size=kernel_size,
        strides=stride,
        padding=padding,
        data_format=data_format,
        dilation_rate=rate,
        activation=None,
        # Bias is skipped when a normalizer is used or no biases_initializer
        # was supplied (truthiness intentional).
        use_bias=not normalizer_fn and biases_initializer,
        kernel_initializer=weights_initializer,
        bias_initializer=biases_initializer,
        kernel_regularizer=weights_regularizer,
        bias_regularizer=biases_regularizer,
        use_weight_standardization=use_weight_standardization,
        activity_regularizer=None,
        trainable=trainable,
        name=sc.name,
        dtype=inputs.dtype.base_dtype,
        _scope=sc,
        _reuse=reuse)
    outputs = layer.apply(inputs)

    # Add variables to collections.
    # pylint: disable=protected-access
    layers._add_variable_to_collections(layer.kernel, variables_collections,
                                        'weights')
    if layer.use_bias:
      layers._add_variable_to_collections(layer.bias, variables_collections,
                                          'biases')
    # pylint: enable=protected-access
    # Normalizer (e.g. batch norm) replaces the bias; applied post-conv.
    if normalizer_fn is not None:
      normalizer_params = normalizer_params or {}
      outputs = normalizer_fn(outputs, **normalizer_params)

    if activation_fn is not None:
      outputs = activation_fn(outputs)
    return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
def conv2d_same(inputs, num_outputs, kernel_size, stride, rate=1, scope=None):
  """Strided 2-D convolution with 'SAME' padding.

  When stride > 1, explicit zero-padding is performed first, followed by a
  conv2d with 'VALID' padding. Note that

     net = conv2d_same(inputs, num_outputs, 3, stride=stride)

  is equivalent to

     net = conv2d(inputs, num_outputs, 3, stride=1, padding='SAME')
     net = subsample(net, factor=stride)

  whereas

     net = conv2d(inputs, num_outputs, 3, stride=stride, padding='SAME')

  is different when the input's height or width is even, which is why we add
  the current function. For more details, see
  ResnetUtilsTest.testConv2DSameEven().

  Args:
    inputs: A 4-D tensor of size [batch, height_in, width_in, channels].
    num_outputs: An integer, the number of output filters.
    kernel_size: An int with the kernel_size of the filters.
    stride: An integer, the output stride.
    rate: An integer, rate for atrous convolution.
    scope: Scope.

  Returns:
    output: A 4-D tensor of size [batch, height_out, width_out, channels] with
      the convolution output.
  """
  # Unstrided case: plain 'SAME' convolution, no manual padding needed.
  if stride == 1:
    return conv2d(
        inputs,
        num_outputs,
        kernel_size,
        stride=1,
        rate=rate,
        padding='SAME',
        scope=scope)

  # Strided case: pad explicitly by the effective (dilated) kernel extent,
  # then convolve with 'VALID' padding.
  effective_kernel = kernel_size + (kernel_size - 1) * (rate - 1)
  total_pad = effective_kernel - 1
  front_pad = total_pad // 2
  back_pad = total_pad - front_pad
  padded = tf.pad(
      inputs, [[0, 0], [front_pad, back_pad], [front_pad, back_pad], [0, 0]])
  return conv2d(
      padded,
      num_outputs,
      kernel_size,
      stride=stride,
      rate=rate,
      padding='VALID',
      scope=scope)
# Lint as: python2, python3
# Copyright 2019 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for conv2d_ws."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from tensorflow.contrib import framework as contrib_framework
from tensorflow.contrib import layers as contrib_layers
from deeplab.core import conv2d_ws
class ConvolutionTest(tf.test.TestCase):
def testInvalidShape(self):
with self.cached_session():
images_3d = tf.random_uniform((5, 6, 7, 9, 3), seed=1)
with self.assertRaisesRegexp(
ValueError, 'Convolution expects input with rank 4, got 5'):
conv2d_ws.conv2d(images_3d, 32, 3)
def testInvalidDataFormat(self):
height, width = 7, 9
with self.cached_session():
images = tf.random_uniform((5, height, width, 3), seed=1)
with self.assertRaisesRegexp(ValueError, 'data_format'):
conv2d_ws.conv2d(images, 32, 3, data_format='CHWN')
def testCreateConv(self):
height, width = 7, 9
with self.cached_session():
images = np.random.uniform(size=(5, height, width, 4)).astype(np.float32)
output = conv2d_ws.conv2d(images, 32, [3, 3])
self.assertEqual(output.op.name, 'Conv/Relu')
self.assertListEqual(output.get_shape().as_list(), [5, height, width, 32])
weights = contrib_framework.get_variables_by_name('weights')[0]
self.assertListEqual(weights.get_shape().as_list(), [3, 3, 4, 32])
biases = contrib_framework.get_variables_by_name('biases')[0]
self.assertListEqual(biases.get_shape().as_list(), [32])
def testCreateConvWithWS(self):
height, width = 7, 9
with self.cached_session():
images = np.random.uniform(size=(5, height, width, 4)).astype(np.float32)
output = conv2d_ws.conv2d(
images, 32, [3, 3], use_weight_standardization=True)
self.assertEqual(output.op.name, 'Conv/Relu')
self.assertListEqual(output.get_shape().as_list(), [5, height, width, 32])
weights = contrib_framework.get_variables_by_name('weights')[0]
self.assertListEqual(weights.get_shape().as_list(), [3, 3, 4, 32])
biases = contrib_framework.get_variables_by_name('biases')[0]
self.assertListEqual(biases.get_shape().as_list(), [32])
def testCreateConvNCHW(self):
height, width = 7, 9
with self.cached_session():
images = np.random.uniform(size=(5, 4, height, width)).astype(np.float32)
output = conv2d_ws.conv2d(images, 32, [3, 3], data_format='NCHW')
self.assertEqual(output.op.name, 'Conv/Relu')
self.assertListEqual(output.get_shape().as_list(), [5, 32, height, width])
weights = contrib_framework.get_variables_by_name('weights')[0]
self.assertListEqual(weights.get_shape().as_list(), [3, 3, 4, 32])
biases = contrib_framework.get_variables_by_name('biases')[0]
self.assertListEqual(biases.get_shape().as_list(), [32])
def testCreateSquareConv(self):
height, width = 7, 9
with self.cached_session():
images = tf.random_uniform((5, height, width, 3), seed=1)
output = conv2d_ws.conv2d(images, 32, 3)
self.assertEqual(output.op.name, 'Conv/Relu')
self.assertListEqual(output.get_shape().as_list(), [5, height, width, 32])
def testCreateConvWithTensorShape(self):
height, width = 7, 9
with self.cached_session():
images = tf.random_uniform((5, height, width, 3), seed=1)
output = conv2d_ws.conv2d(images, 32, images.get_shape()[1:3])
self.assertEqual(output.op.name, 'Conv/Relu')
self.assertListEqual(output.get_shape().as_list(), [5, height, width, 32])
def testCreateFullyConv(self):
height, width = 7, 9
with self.cached_session():
images = tf.random_uniform((5, height, width, 32), seed=1)
output = conv2d_ws.conv2d(
images, 64, images.get_shape()[1:3], padding='VALID')
self.assertEqual(output.op.name, 'Conv/Relu')
self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 64])
biases = contrib_framework.get_variables_by_name('biases')[0]
self.assertListEqual(biases.get_shape().as_list(), [64])
def testFullyConvWithCustomGetter(self):
height, width = 7, 9
with self.cached_session():
called = [0]
def custom_getter(getter, *args, **kwargs):
called[0] += 1
return getter(*args, **kwargs)
with tf.variable_scope('test', custom_getter=custom_getter):
images = tf.random_uniform((5, height, width, 32), seed=1)
conv2d_ws.conv2d(images, 64, images.get_shape()[1:3])
self.assertEqual(called[0], 2) # Custom getter called twice.
def testCreateVerticalConv(self):
height, width = 7, 9
with self.cached_session():
images = tf.random_uniform((5, height, width, 4), seed=1)
output = conv2d_ws.conv2d(images, 32, [3, 1])
self.assertEqual(output.op.name, 'Conv/Relu')
self.assertListEqual(output.get_shape().as_list(), [5, height, width, 32])
weights = contrib_framework.get_variables_by_name('weights')[0]
self.assertListEqual(weights.get_shape().as_list(), [3, 1, 4, 32])
biases = contrib_framework.get_variables_by_name('biases')[0]
self.assertListEqual(biases.get_shape().as_list(), [32])
def testCreateHorizontalConv(self):
height, width = 7, 9
with self.cached_session():
images = tf.random_uniform((5, height, width, 4), seed=1)
output = conv2d_ws.conv2d(images, 32, [1, 3])
self.assertEqual(output.op.name, 'Conv/Relu')
self.assertListEqual(output.get_shape().as_list(), [5, height, width, 32])
weights = contrib_framework.get_variables_by_name('weights')[0]
self.assertListEqual(weights.get_shape().as_list(), [1, 3, 4, 32])
def testCreateConvWithStride(self):
height, width = 6, 8
with self.cached_session():
images = tf.random_uniform((5, height, width, 3), seed=1)
output = conv2d_ws.conv2d(images, 32, [3, 3], stride=2)
self.assertEqual(output.op.name, 'Conv/Relu')
self.assertListEqual(output.get_shape().as_list(),
[5, height / 2, width / 2, 32])
def testCreateConvCreatesWeightsAndBiasesVars(self):
height, width = 7, 9
images = tf.random_uniform((5, height, width, 3), seed=1)
with self.cached_session():
self.assertFalse(contrib_framework.get_variables('conv1/weights'))
self.assertFalse(contrib_framework.get_variables('conv1/biases'))
conv2d_ws.conv2d(images, 32, [3, 3], scope='conv1')
self.assertTrue(contrib_framework.get_variables('conv1/weights'))
self.assertTrue(contrib_framework.get_variables('conv1/biases'))
def testCreateConvWithScope(self):
height, width = 7, 9
with self.cached_session():
images = tf.random_uniform((5, height, width, 3), seed=1)
output = conv2d_ws.conv2d(images, 32, [3, 3], scope='conv1')
self.assertEqual(output.op.name, 'conv1/Relu')
def testCreateConvWithCollection(self):
height, width = 7, 9
images = tf.random_uniform((5, height, width, 3), seed=1)
with tf.name_scope('fe'):
conv = conv2d_ws.conv2d(
images, 32, [3, 3], outputs_collections='outputs', scope='Conv')
output_collected = tf.get_collection('outputs')[0]
self.assertEqual(output_collected.aliases, ['Conv'])
self.assertEqual(output_collected, conv)
def testCreateConvWithoutActivation(self):
height, width = 7, 9
with self.cached_session():
images = tf.random_uniform((5, height, width, 3), seed=1)
output = conv2d_ws.conv2d(images, 32, [3, 3], activation_fn=None)
self.assertEqual(output.op.name, 'Conv/BiasAdd')
def testCreateConvValid(self):
height, width = 7, 9
with self.cached_session():
images = tf.random_uniform((5, height, width, 3), seed=1)
output = conv2d_ws.conv2d(images, 32, [3, 3], padding='VALID')
self.assertListEqual(output.get_shape().as_list(), [5, 5, 7, 32])
def testCreateConvWithWD(self):
height, width = 7, 9
weight_decay = 0.01
with self.cached_session() as sess:
images = tf.random_uniform((5, height, width, 3), seed=1)
regularizer = contrib_layers.l2_regularizer(weight_decay)
conv2d_ws.conv2d(images, 32, [3, 3], weights_regularizer=regularizer)
l2_loss = tf.nn.l2_loss(
contrib_framework.get_variables_by_name('weights')[0])
wd = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)[0]
self.assertEqual(wd.op.name, 'Conv/kernel/Regularizer/l2_regularizer')
sess.run(tf.global_variables_initializer())
self.assertAlmostEqual(sess.run(wd), weight_decay * l2_loss.eval())
def testCreateConvNoRegularizers(self):
height, width = 7, 9
with self.cached_session():
images = tf.random_uniform((5, height, width, 3), seed=1)
conv2d_ws.conv2d(images, 32, [3, 3])
self.assertEqual(
tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES), [])
def testReuseVars(self):
height, width = 7, 9
with self.cached_session():
images = tf.random_uniform((5, height, width, 3), seed=1)
conv2d_ws.conv2d(images, 32, [3, 3], scope='conv1')
self.assertEqual(len(contrib_framework.get_variables()), 2)
conv2d_ws.conv2d(images, 32, [3, 3], scope='conv1', reuse=True)
self.assertEqual(len(contrib_framework.get_variables()), 2)
def testNonReuseVars(self):
height, width = 7, 9
with self.cached_session():
images = tf.random_uniform((5, height, width, 3), seed=1)
conv2d_ws.conv2d(images, 32, [3, 3])
self.assertEqual(len(contrib_framework.get_variables()), 2)
conv2d_ws.conv2d(images, 32, [3, 3])
self.assertEqual(len(contrib_framework.get_variables()), 4)
def testReuseConvWithWD(self):
height, width = 7, 9
with self.cached_session():
images = tf.random_uniform((5, height, width, 3), seed=1)
weight_decay = contrib_layers.l2_regularizer(0.01)
with contrib_framework.arg_scope([conv2d_ws.conv2d],
weights_regularizer=weight_decay):
conv2d_ws.conv2d(images, 32, [3, 3], scope='conv1')
self.assertEqual(len(contrib_framework.get_variables()), 2)
self.assertEqual(
len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)), 1)
conv2d_ws.conv2d(images, 32, [3, 3], scope='conv1', reuse=True)
self.assertEqual(len(contrib_framework.get_variables()), 2)
self.assertEqual(
len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)), 1)
def testConvWithBatchNorm(self):
height, width = 7, 9
with self.cached_session():
images = tf.random_uniform((5, height, width, 32), seed=1)
with contrib_framework.arg_scope([conv2d_ws.conv2d],
normalizer_fn=contrib_layers.batch_norm,
normalizer_params={'decay': 0.9}):
net = conv2d_ws.conv2d(images, 32, [3, 3])
net = conv2d_ws.conv2d(net, 32, [3, 3])
self.assertEqual(len(contrib_framework.get_variables()), 8)
self.assertEqual(
len(contrib_framework.get_variables('Conv/BatchNorm')), 3)
self.assertEqual(
len(contrib_framework.get_variables('Conv_1/BatchNorm')), 3)
def testReuseConvWithBatchNorm(self):
height, width = 7, 9
with self.cached_session():
images = tf.random_uniform((5, height, width, 32), seed=1)
with contrib_framework.arg_scope([conv2d_ws.conv2d],
normalizer_fn=contrib_layers.batch_norm,
normalizer_params={'decay': 0.9}):
net = conv2d_ws.conv2d(images, 32, [3, 3], scope='Conv')
net = conv2d_ws.conv2d(net, 32, [3, 3], scope='Conv', reuse=True)
self.assertEqual(len(contrib_framework.get_variables()), 4)
self.assertEqual(
len(contrib_framework.get_variables('Conv/BatchNorm')), 3)
self.assertEqual(
len(contrib_framework.get_variables('Conv_1/BatchNorm')), 0)
def testCreateConvCreatesWeightsAndBiasesVarsWithRateTwo(self):
height, width = 7, 9
images = tf.random_uniform((5, height, width, 3), seed=1)
with self.cached_session():
self.assertFalse(contrib_framework.get_variables('conv1/weights'))
self.assertFalse(contrib_framework.get_variables('conv1/biases'))
conv2d_ws.conv2d(images, 32, [3, 3], rate=2, scope='conv1')
self.assertTrue(contrib_framework.get_variables('conv1/weights'))
self.assertTrue(contrib_framework.get_variables('conv1/biases'))
def testOutputSizeWithRateTwoSamePadding(self):
num_filters = 32
input_size = [5, 10, 12, 3]
expected_size = [5, 10, 12, num_filters]
images = tf.random_uniform(input_size, seed=1)
output = conv2d_ws.conv2d(
images, num_filters, [3, 3], rate=2, padding='SAME')
self.assertListEqual(list(output.get_shape().as_list()), expected_size)
with self.cached_session() as sess:
sess.run(tf.global_variables_initializer())
self.assertEqual(output.op.name, 'Conv/Relu')
self.assertListEqual(list(output.eval().shape), expected_size)
def testOutputSizeWithRateTwoValidPadding(self):
num_filters = 32
input_size = [5, 10, 12, 3]
expected_size = [5, 6, 8, num_filters]
images = tf.random_uniform(input_size, seed=1)
output = conv2d_ws.conv2d(
images, num_filters, [3, 3], rate=2, padding='VALID')
self.assertListEqual(list(output.get_shape().as_list()), expected_size)
with self.cached_session() as sess:
sess.run(tf.global_variables_initializer())
self.assertEqual(output.op.name, 'Conv/Relu')
self.assertListEqual(list(output.eval().shape), expected_size)
def testOutputSizeWithRateTwoThreeValidPadding(self):
num_filters = 32
input_size = [5, 10, 12, 3]
expected_size = [5, 6, 6, num_filters]
images = tf.random_uniform(input_size, seed=1)
output = conv2d_ws.conv2d(
images, num_filters, [3, 3], rate=[2, 3], padding='VALID')
self.assertListEqual(list(output.get_shape().as_list()), expected_size)
with self.cached_session() as sess:
sess.run(tf.global_variables_initializer())
self.assertEqual(output.op.name, 'Conv/Relu')
self.assertListEqual(list(output.eval().shape), expected_size)
def testDynamicOutputSizeWithRateOneValidPadding(self):
num_filters = 32
input_size = [5, 9, 11, 3]
expected_size = [None, None, None, num_filters]
expected_size_dynamic = [5, 7, 9, num_filters]
with self.cached_session():
images = tf.placeholder(np.float32, [None, None, None, input_size[3]])
output = conv2d_ws.conv2d(
images, num_filters, [3, 3], rate=1, padding='VALID')
tf.global_variables_initializer().run()
self.assertEqual(output.op.name, 'Conv/Relu')
self.assertListEqual(output.get_shape().as_list(), expected_size)
eval_output = output.eval({images: np.zeros(input_size, np.float32)})
self.assertListEqual(list(eval_output.shape), expected_size_dynamic)
def testDynamicOutputSizeWithRateOneValidPaddingNCHW(self):
if tf.test.is_gpu_available(cuda_only=True):
num_filters = 32
input_size = [5, 3, 9, 11]
expected_size = [None, num_filters, None, None]
expected_size_dynamic = [5, num_filters, 7, 9]
with self.session(use_gpu=True):
images = tf.placeholder(np.float32, [None, input_size[1], None, None])
output = conv2d_ws.conv2d(
images,
num_filters, [3, 3],
rate=1,
padding='VALID',
data_format='NCHW')
tf.global_variables_initializer().run()
self.assertEqual(output.op.name, 'Conv/Relu')
self.assertListEqual(output.get_shape().as_list(), expected_size)
eval_output = output.eval({images: np.zeros(input_size, np.float32)})
self.assertListEqual(list(eval_output.shape), expected_size_dynamic)
def testDynamicOutputSizeWithRateTwoValidPadding(self):
num_filters = 32
input_size = [5, 9, 11, 3]
expected_size = [None, None, None, num_filters]
expected_size_dynamic = [5, 5, 7, num_filters]
with self.cached_session():
images = tf.placeholder(np.float32, [None, None, None, input_size[3]])
output = conv2d_ws.conv2d(
images, num_filters, [3, 3], rate=2, padding='VALID')
tf.global_variables_initializer().run()
self.assertEqual(output.op.name, 'Conv/Relu')
self.assertListEqual(output.get_shape().as_list(), expected_size)
eval_output = output.eval({images: np.zeros(input_size, np.float32)})
self.assertListEqual(list(eval_output.shape), expected_size_dynamic)
def testWithScope(self):
num_filters = 32
input_size = [5, 9, 11, 3]
expected_size = [5, 5, 7, num_filters]
images = tf.random_uniform(input_size, seed=1)
output = conv2d_ws.conv2d(
images, num_filters, [3, 3], rate=2, padding='VALID', scope='conv7')
with self.cached_session() as sess:
sess.run(tf.global_variables_initializer())
self.assertEqual(output.op.name, 'conv7/Relu')
self.assertListEqual(list(output.eval().shape), expected_size)
def testWithScopeWithoutActivation(self):
num_filters = 32
input_size = [5, 9, 11, 3]
expected_size = [5, 5, 7, num_filters]
images = tf.random_uniform(input_size, seed=1)
output = conv2d_ws.conv2d(
images,
num_filters, [3, 3],
rate=2,
padding='VALID',
activation_fn=None,
scope='conv7')
with self.cached_session() as sess:
sess.run(tf.global_variables_initializer())
self.assertEqual(output.op.name, 'conv7/BiasAdd')
self.assertListEqual(list(output.eval().shape), expected_size)
if __name__ == '__main__':
  # Run all tf.test.TestCase tests defined in this file.
  tf.test.main()
# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -14,25 +15,32 @@
# ==============================================================================
"""Extracts features for different models."""
import copy
import functools
import tensorflow as tf
from tensorflow.contrib import slim as contrib_slim
from deeplab.core import nas_network
from deeplab.core import resnet_v1_beta
from deeplab.core import xception
from tensorflow.contrib.slim.nets import resnet_utils
from nets.mobilenet import conv_blocks
from nets.mobilenet import mobilenet
from nets.mobilenet import mobilenet_v2
from nets.mobilenet import mobilenet_v3
slim = tf.contrib.slim
slim = contrib_slim
# Default end point for MobileNetv2.
_MOBILENET_V2_FINAL_ENDPOINT = 'layer_18'
_MOBILENET_V3_LARGE_FINAL_ENDPOINT = 'layer_17'
_MOBILENET_V3_SMALL_FINAL_ENDPOINT = 'layer_13'
def _mobilenet_v2(net,
depth_multiplier,
output_stride,
conv_defs=None,
divisible_by=None,
reuse=None,
scope=None,
......@@ -50,6 +58,7 @@ def _mobilenet_v2(net,
if necessary to prevent the network from reducing the spatial resolution
of the activation maps. Allowed values are 8 (accurate fully convolutional
mode), 16 (fast fully convolutional mode), 32 (classification mode).
conv_defs: MobileNet con def.
divisible_by: None (use default setting) or an integer that ensures all
layers # channels will be divisible by this number. Used in MobileNet.
reuse: Reuse model variables.
......@@ -61,11 +70,13 @@ def _mobilenet_v2(net,
"""
if divisible_by is None:
divisible_by = 8 if depth_multiplier == 1.0 else 1
if conv_defs is None:
conv_defs = mobilenet_v2.V2_DEF
with tf.variable_scope(
scope, 'MobilenetV2', [net], reuse=reuse) as scope:
return mobilenet_v2.mobilenet_base(
net,
conv_defs=mobilenet_v2.V2_DEF,
conv_defs=conv_defs,
depth_multiplier=depth_multiplier,
min_depth=8 if depth_multiplier == 1.0 else 1,
divisible_by=divisible_by,
......@@ -74,9 +85,130 @@ def _mobilenet_v2(net,
scope=scope)
def _mobilenet_v3(net,
                  depth_multiplier,
                  output_stride,
                  conv_defs=None,
                  divisible_by=None,
                  reuse=None,
                  scope=None,
                  final_endpoint=None):
  """Builds the mobilenet v3 backbone up to `final_endpoint`.

  Args:
    net: Input tensor of shape [batch_size, height, width, channels].
    depth_multiplier: Float multiplier for the depth (number of channels) of
      all convolution ops. Must be greater than zero; values in (0, 1) reduce
      the number of parameters or computation cost of the model.
    output_stride: An integer specifying the requested ratio of input to
      output spatial resolution. If not None, atrous convolution is invoked
      as necessary to keep the activation maps at this resolution. Allowed
      values are 8 (accurate fully convolutional mode), 16 (fast fully
      convolutional mode) and 32 (classification mode).
    conv_defs: A list of ConvDef namedtuples specifying the net architecture.
    divisible_by: Ignored by mobilenet v3; kept for signature compatibility
      with the other feature extractors.
    reuse: Reuse model variables.
    scope: Optional variable scope.
    final_endpoint: The endpoint to construct the network up to.

  Returns:
    A (net, end_points) tuple: the output tensor and a set of activations for
    external use.

  Raises:
    ValueError: If conv_defs or final_endpoint is not specified.
  """
  del divisible_by  # Not used by mobilenet v3.
  with tf.variable_scope(scope, 'MobilenetV3', [net], reuse=reuse) as scope:
    if conv_defs is None:
      raise ValueError('conv_defs must be specified for mobilenet v3.')
    if final_endpoint is None:
      raise ValueError('Final endpoint must be specified for mobilenet v3.')
    return mobilenet_v3.mobilenet_base(
        net,
        depth_multiplier=depth_multiplier,
        conv_defs=conv_defs,
        output_stride=output_stride,
        final_endpoint=final_endpoint,
        scope=scope)
def mobilenet_v3_large_seg(net,
                           depth_multiplier,
                           output_stride,
                           divisible_by=None,
                           reuse=None,
                           scope=None,
                           final_endpoint=None):
  """Final mobilenet v3 large model for segmentation task.

  Args:
    net: Input tensor of shape [batch_size, height, width, channels].
    depth_multiplier: Float multiplier for the depth (number of channels)
      for all convolution ops.
    output_stride: Requested ratio of input to output spatial resolution.
    divisible_by: Ignored; the divisor is fixed to 8 for this variant.
    reuse: Reuse model variables.
    scope: Optional variable scope.
    final_endpoint: Ignored; the endpoint is fixed for this variant.

  Returns:
    A (net, end_points) tuple from the underlying mobilenet v3 builder.
  """
  del divisible_by
  del final_endpoint
  conv_defs = copy.deepcopy(mobilenet_v3.V3_LARGE)
  # Reduce the filters by a factor of 2 in the last block. Use floor division
  # so the channel count stays an integer under Python 3, where `/=` would
  # turn `num_outputs` into a float (the original counts are even, so the
  # result is unchanged numerically).
  for layer, expansion in [(13, 336), (14, 480), (15, 480), (16, None)]:
    conv_defs['spec'][layer].params['num_outputs'] //= 2
    # Update the expansion size to match the halved output channels.
    if expansion is not None:
      factor = expansion / conv_defs['spec'][layer - 1].params['num_outputs']
      conv_defs['spec'][layer].params[
          'expansion_size'] = mobilenet_v3.expand_input(factor)
  return _mobilenet_v3(
      net,
      depth_multiplier=depth_multiplier,
      output_stride=output_stride,
      divisible_by=8,
      conv_defs=conv_defs,
      reuse=reuse,
      scope=scope,
      final_endpoint=_MOBILENET_V3_LARGE_FINAL_ENDPOINT)
def mobilenet_v3_small_seg(net,
                           depth_multiplier,
                           output_stride,
                           divisible_by=None,
                           reuse=None,
                           scope=None,
                           final_endpoint=None):
  """Final mobilenet v3 small model for segmentation task.

  Args:
    net: Input tensor of shape [batch_size, height, width, channels].
    depth_multiplier: Float multiplier for the depth (number of channels)
      for all convolution ops.
    output_stride: Requested ratio of input to output spatial resolution.
    divisible_by: Ignored; the divisor is fixed to 8 for this variant.
    reuse: Reuse model variables.
    scope: Optional variable scope.
    final_endpoint: Ignored; the endpoint is fixed for this variant.

  Returns:
    A (net, end_points) tuple from the underlying mobilenet v3 builder.
  """
  del divisible_by
  del final_endpoint
  conv_defs = copy.deepcopy(mobilenet_v3.V3_SMALL)
  # Reduce the filters by a factor of 2 in the last block. Use floor division
  # so the channel count stays an integer under Python 3, where `/=` would
  # turn `num_outputs` into a float (the original counts are even, so the
  # result is unchanged numerically).
  for layer, expansion in [(9, 144), (10, 288), (11, 288), (12, None)]:
    conv_defs['spec'][layer].params['num_outputs'] //= 2
    # Update the expansion size to match the halved output channels.
    if expansion is not None:
      factor = expansion / conv_defs['spec'][layer - 1].params['num_outputs']
      conv_defs['spec'][layer].params[
          'expansion_size'] = mobilenet_v3.expand_input(factor)
  return _mobilenet_v3(
      net,
      depth_multiplier=depth_multiplier,
      output_stride=output_stride,
      divisible_by=8,
      conv_defs=conv_defs,
      reuse=reuse,
      scope=scope,
      final_endpoint=_MOBILENET_V3_SMALL_FINAL_ENDPOINT)
# A map from network name to network function.
networks_map = {
'mobilenet_v2': _mobilenet_v2,
'mobilenet_v3_large_seg': mobilenet_v3_large_seg,
'mobilenet_v3_small_seg': mobilenet_v3_small_seg,
'resnet_v1_18': resnet_v1_beta.resnet_v1_18,
'resnet_v1_18_beta': resnet_v1_beta.resnet_v1_18_beta,
'resnet_v1_50': resnet_v1_beta.resnet_v1_50,
'resnet_v1_50_beta': resnet_v1_beta.resnet_v1_50_beta,
'resnet_v1_101': resnet_v1_beta.resnet_v1_101,
......@@ -88,13 +220,88 @@ networks_map = {
'nas_hnasnet': nas_network.hnasnet,
}
def mobilenet_v2_arg_scope(is_training=True,
                           weight_decay=0.00004,
                           stddev=0.09,
                           activation=tf.nn.relu6,
                           bn_decay=0.997,
                           bn_epsilon=None,
                           bn_renorm=None):
  """Defines the default MobilenetV2 arg scope.

  Args:
    is_training: Whether or not we're training the model. If this is set to
      None, the is_training parameter in batch_norm is not set. Please note
      that this also sets the is_training parameter in dropout to None.
    weight_decay: The weight decay to use for regularizing the model.
    stddev: Standard deviation for initialization; if negative, the Xavier
      initializer is used instead.
    activation: Activation function applied after batch norm (e.g. relu6).
    bn_decay: Decay for the batch norm moving averages.
    bn_epsilon: Batch normalization epsilon.
    bn_renorm: Whether to use batch norm renormalization.

  Returns:
    An `arg_scope` to use for the mobilenet v2 model.
  """
  batch_norm_params = {
      'center': True,
      'scale': True,
      'decay': bn_decay,
  }
  if bn_epsilon is not None:
    batch_norm_params['epsilon'] = bn_epsilon
  if is_training is not None:
    batch_norm_params['is_training'] = is_training
  if bn_renorm is not None:
    batch_norm_params['renorm'] = bn_renorm
  dropout_params = {}
  if is_training is not None:
    dropout_params['is_training'] = is_training

  instance_norm_params = {
      'center': True,
      'scale': True,
      'epsilon': 0.001,
  }

  # Negative stddev selects the Xavier initializer instead of a truncated
  # normal.  (Fixed typo: was `weight_intitializer`.)
  if stddev < 0:
    weight_initializer = slim.initializers.xavier_initializer()
  else:
    weight_initializer = tf.truncated_normal_initializer(stddev=stddev)

  # Set weight_decay for weights in Conv and FC layers.
  with slim.arg_scope(
      [slim.conv2d, slim.fully_connected, slim.separable_conv2d],
      weights_initializer=weight_initializer,
      activation_fn=activation,
      normalizer_fn=slim.batch_norm), \
      slim.arg_scope(
          [conv_blocks.expanded_conv], normalizer_fn=slim.batch_norm), \
      slim.arg_scope([mobilenet.apply_activation], activation_fn=activation),\
      slim.arg_scope([slim.batch_norm], **batch_norm_params), \
      slim.arg_scope([mobilenet.mobilenet_base, mobilenet.mobilenet],
                     is_training=is_training),\
      slim.arg_scope([slim.dropout], **dropout_params), \
      slim.arg_scope([slim.instance_norm], **instance_norm_params), \
      slim.arg_scope([slim.conv2d], \
                     weights_regularizer=slim.l2_regularizer(weight_decay)), \
      slim.arg_scope([slim.separable_conv2d], weights_regularizer=None), \
      slim.arg_scope([slim.conv2d, slim.separable_conv2d], padding='SAME') as s:
    return s
# A map from network name to network arg scope.
arg_scopes_map = {
'mobilenet_v2': mobilenet_v2.training_scope,
'resnet_v1_50': resnet_utils.resnet_arg_scope,
'resnet_v1_50_beta': resnet_utils.resnet_arg_scope,
'resnet_v1_101': resnet_utils.resnet_arg_scope,
'resnet_v1_101_beta': resnet_utils.resnet_arg_scope,
'mobilenet_v3_large_seg': mobilenet_v2_arg_scope,
'mobilenet_v3_small_seg': mobilenet_v2_arg_scope,
'resnet_v1_18': resnet_v1_beta.resnet_arg_scope,
'resnet_v1_18_beta': resnet_v1_beta.resnet_arg_scope,
'resnet_v1_50': resnet_v1_beta.resnet_arg_scope,
'resnet_v1_50_beta': resnet_v1_beta.resnet_arg_scope,
'resnet_v1_101': resnet_v1_beta.resnet_arg_scope,
'resnet_v1_101_beta': resnet_v1_beta.resnet_arg_scope,
'xception_41': xception.xception_arg_scope,
'xception_65': xception.xception_arg_scope,
'xception_71': xception.xception_arg_scope,
......@@ -110,54 +317,108 @@ networks_to_feature_maps = {
'mobilenet_v2': {
DECODER_END_POINTS: {
4: ['layer_4/depthwise_output'],
8: ['layer_7/depthwise_output'],
16: ['layer_14/depthwise_output'],
},
},
'mobilenet_v3_large_seg': {
DECODER_END_POINTS: {
4: ['layer_4/depthwise_output'],
8: ['layer_7/depthwise_output'],
16: ['layer_13/depthwise_output'],
},
},
'mobilenet_v3_small_seg': {
DECODER_END_POINTS: {
4: ['layer_2/depthwise_output'],
8: ['layer_4/depthwise_output'],
16: ['layer_9/depthwise_output'],
},
},
'resnet_v1_18': {
DECODER_END_POINTS: {
4: ['block1/unit_1/lite_bottleneck_v1/conv2'],
8: ['block2/unit_1/lite_bottleneck_v1/conv2'],
16: ['block3/unit_1/lite_bottleneck_v1/conv2'],
},
},
'resnet_v1_18_beta': {
DECODER_END_POINTS: {
4: ['block1/unit_1/lite_bottleneck_v1/conv2'],
8: ['block2/unit_1/lite_bottleneck_v1/conv2'],
16: ['block3/unit_1/lite_bottleneck_v1/conv2'],
},
},
'resnet_v1_50': {
DECODER_END_POINTS: {
4: ['block1/unit_2/bottleneck_v1/conv3'],
8: ['block2/unit_3/bottleneck_v1/conv3'],
16: ['block3/unit_5/bottleneck_v1/conv3'],
},
},
'resnet_v1_50_beta': {
DECODER_END_POINTS: {
4: ['block1/unit_2/bottleneck_v1/conv3'],
8: ['block2/unit_3/bottleneck_v1/conv3'],
16: ['block3/unit_5/bottleneck_v1/conv3'],
},
},
'resnet_v1_101': {
DECODER_END_POINTS: {
4: ['block1/unit_2/bottleneck_v1/conv3'],
8: ['block2/unit_3/bottleneck_v1/conv3'],
16: ['block3/unit_22/bottleneck_v1/conv3'],
},
},
'resnet_v1_101_beta': {
DECODER_END_POINTS: {
4: ['block1/unit_2/bottleneck_v1/conv3'],
8: ['block2/unit_3/bottleneck_v1/conv3'],
16: ['block3/unit_22/bottleneck_v1/conv3'],
},
},
'xception_41': {
DECODER_END_POINTS: {
4: ['entry_flow/block2/unit_1/xception_module/'
'separable_conv2_pointwise'],
8: ['entry_flow/block3/unit_1/xception_module/'
'separable_conv2_pointwise'],
16: ['exit_flow/block1/unit_1/xception_module/'
'separable_conv2_pointwise'],
},
},
'xception_65': {
DECODER_END_POINTS: {
4: ['entry_flow/block2/unit_1/xception_module/'
'separable_conv2_pointwise'],
8: ['entry_flow/block3/unit_1/xception_module/'
'separable_conv2_pointwise'],
16: ['exit_flow/block1/unit_1/xception_module/'
'separable_conv2_pointwise'],
},
},
'xception_71': {
DECODER_END_POINTS: {
4: ['entry_flow/block3/unit_1/xception_module/'
'separable_conv2_pointwise'],
8: ['entry_flow/block5/unit_1/xception_module/'
'separable_conv2_pointwise'],
16: ['exit_flow/block1/unit_1/xception_module/'
'separable_conv2_pointwise'],
},
},
'nas_pnasnet': {
DECODER_END_POINTS: {
4: ['Stem'],
8: ['Cell_3'],
16: ['Cell_7'],
},
},
'nas_hnasnet': {
DECODER_END_POINTS: {
4: ['Cell_2'],
8: ['Cell_5'],
16: ['Cell_7'],
},
},
}
......@@ -166,6 +427,10 @@ networks_to_feature_maps = {
# ImageNet pretrained versions of these models.
name_scope = {
'mobilenet_v2': 'MobilenetV2',
'mobilenet_v3_large_seg': 'MobilenetV3',
'mobilenet_v3_small_seg': 'MobilenetV3',
'resnet_v1_18': 'resnet_v1_18',
'resnet_v1_18_beta': 'resnet_v1_18',
'resnet_v1_50': 'resnet_v1_50',
'resnet_v1_50_beta': 'resnet_v1_50',
'resnet_v1_101': 'resnet_v1_101',
......@@ -199,6 +464,10 @@ def _preprocess_zero_mean_unit_range(inputs, dtype=tf.float32):
_PREPROCESS_FN = {
'mobilenet_v2': _preprocess_zero_mean_unit_range,
'mobilenet_v3_large_seg': _preprocess_zero_mean_unit_range,
'mobilenet_v3_small_seg': _preprocess_zero_mean_unit_range,
'resnet_v1_18': _preprocess_subtract_imagenet_mean,
'resnet_v1_18_beta': _preprocess_zero_mean_unit_range,
'resnet_v1_50': _preprocess_subtract_imagenet_mean,
'resnet_v1_50_beta': _preprocess_zero_mean_unit_range,
'resnet_v1_101': _preprocess_subtract_imagenet_mean,
......@@ -252,7 +521,7 @@ def extract_features(images,
preprocessed_images_dtype=tf.float32,
num_classes=None,
global_pool=False,
nas_stem_output_num_conv_filters=20,
nas_architecture_options=None,
nas_training_hyper_parameters=None,
use_bounded_activation=False):
"""Extracts features by the particular model_variant.
......@@ -282,8 +551,11 @@ def extract_features(images,
to None for dense prediction tasks.
global_pool: Global pooling for image classification task. Defaults to
False, since dense prediction tasks do not use this.
nas_stem_output_num_conv_filters: Number of filters of the NAS stem output
tensor.
nas_architecture_options: A dictionary storing NAS architecture options.
It is either None or its kerys are:
- `nas_stem_output_num_conv_filters`: Number of filters of the NAS stem
output tensor.
- `nas_use_classification_head`: Boolean, use image classification head.
nas_training_hyper_parameters: A dictionary storing hyper-parameters for
training nas models. It is either None or its keys are:
- `drop_path_keep_prob`: Probability to keep each path in the cell when
......@@ -339,7 +611,7 @@ def extract_features(images,
multi_grid=multi_grid,
reuse=reuse,
scope=name_scope[model_variant])
elif 'mobilenet' in model_variant:
elif 'mobilenet' in model_variant or model_variant.startswith('mnas'):
arg_scope = arg_scopes_map[model_variant](
is_training=(is_training and fine_tune_batch_norm),
weight_decay=weight_decay)
......@@ -364,7 +636,7 @@ def extract_features(images,
is_training=(is_training and fine_tune_batch_norm),
global_pool=global_pool,
output_stride=output_stride,
nas_stem_output_num_conv_filters=nas_stem_output_num_conv_filters,
nas_architecture_options=nas_architecture_options,
nas_training_hyper_parameters=nas_training_hyper_parameters,
reuse=reuse,
scope=name_scope[model_variant])
......
# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -19,32 +20,50 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from six.moves import range
from six.moves import zip
import tensorflow as tf
from tensorflow.contrib import framework as contrib_framework
from tensorflow.contrib import slim as contrib_slim
from deeplab.core import xception as xception_utils
from deeplab.core.utils import resize_bilinear
from deeplab.core.utils import scale_dimension
from tensorflow.contrib.slim.nets import resnet_utils
arg_scope = tf.contrib.framework.arg_scope
slim = tf.contrib.slim
arg_scope = contrib_framework.arg_scope
slim = contrib_slim
separable_conv2d_same = functools.partial(xception_utils.separable_conv2d_same,
regularize_depthwise=True)
class NASBaseCell(object):
"""NASNet Cell class that is used as a 'layer' in image architectures.
See https://arxiv.org/abs/1707.07012 and https://arxiv.org/abs/1712.00559.
Args:
num_conv_filters: The number of filters for each convolution operation.
operations: List of operations that are performed in the NASNet Cell in
order.
used_hiddenstates: Binary array that signals if the hiddenstate was used
within the cell. This is used to determine what outputs of the cell
should be concatenated together.
hiddenstate_indices: Determines what hiddenstates should be combined
together with the specified operations to create the NASNet cell.
"""
"""NASNet Cell class that is used as a 'layer' in image architectures."""
def __init__(self, num_conv_filters, operations, used_hiddenstates,
hiddenstate_indices, drop_path_keep_prob, total_num_cells,
total_training_steps):
total_training_steps, batch_norm_fn=slim.batch_norm):
"""Init function.
For more details about NAS cell, see
https://arxiv.org/abs/1707.07012 and https://arxiv.org/abs/1712.00559.
Args:
num_conv_filters: The number of filters for each convolution operation.
operations: List of operations that are performed in the NASNet Cell in
order.
used_hiddenstates: Binary array that signals if the hiddenstate was used
within the cell. This is used to determine what outputs of the cell
should be concatenated together.
hiddenstate_indices: Determines what hiddenstates should be combined
together with the specified operations to create the NASNet cell.
drop_path_keep_prob: Float, drop path keep probability.
total_num_cells: Integer, total number of cells.
total_training_steps: Integer, total training steps.
batch_norm_fn: Function, batch norm function. Defaults to
slim.batch_norm.
"""
if len(hiddenstate_indices) != len(operations):
raise ValueError(
'Number of hiddenstate_indices and operations should be the same.')
......@@ -57,6 +76,7 @@ class NASBaseCell(object):
self._drop_path_keep_prob = drop_path_keep_prob
self._total_num_cells = total_num_cells
self._total_training_steps = total_training_steps
self._batch_norm_fn = batch_norm_fn
def __call__(self, net, scope, filter_scaling, stride, prev_layer, cell_num):
"""Runs the conv cell."""
......@@ -100,11 +120,11 @@ class NASBaseCell(object):
if filter_size != prev_layer.shape[3]:
prev_layer = tf.nn.relu(prev_layer)
prev_layer = slim.conv2d(prev_layer, filter_size, 1, scope='prev_1x1')
prev_layer = slim.batch_norm(prev_layer, scope='prev_bn')
prev_layer = self._batch_norm_fn(prev_layer, scope='prev_bn')
net = tf.nn.relu(net)
net = slim.conv2d(net, filter_size, 1, scope='1x1')
net = slim.batch_norm(net, scope='beginning_bn')
net = self._batch_norm_fn(net, scope='beginning_bn')
net = tf.split(axis=3, num_or_size_splits=1, value=net)
net.append(prev_layer)
return net
......@@ -121,14 +141,14 @@ class NASBaseCell(object):
kernel_size = int(operation.split('x')[0][-1])
for layer_num in range(num_layers):
net = tf.nn.relu(net)
net = slim.separable_conv2d(
net = separable_conv2d_same(
net,
filter_size,
kernel_size,
depth_multiplier=1,
scope='separable_{0}x{0}_{1}'.format(kernel_size, layer_num + 1),
stride=stride)
net = slim.batch_norm(
net = self._batch_norm_fn(
net, scope='bn_sep_{0}x{0}_{1}'.format(kernel_size, layer_num + 1))
stride = 1
elif 'atrous' in operation:
......@@ -138,17 +158,19 @@ class NASBaseCell(object):
scaled_height = scale_dimension(tf.shape(net)[1], 0.5)
scaled_width = scale_dimension(tf.shape(net)[2], 0.5)
net = resize_bilinear(net, [scaled_height, scaled_width], net.dtype)
net = slim.conv2d(net, filter_size, kernel_size, rate=1,
scope='atrous_{0}x{0}'.format(kernel_size))
net = resnet_utils.conv2d_same(
net, filter_size, kernel_size, rate=1, stride=1,
scope='atrous_{0}x{0}'.format(kernel_size))
else:
net = slim.conv2d(net, filter_size, kernel_size, rate=2,
scope='atrous_{0}x{0}'.format(kernel_size))
net = slim.batch_norm(net, scope='bn_atr_{0}x{0}'.format(kernel_size))
net = resnet_utils.conv2d_same(
net, filter_size, kernel_size, rate=2, stride=1,
scope='atrous_{0}x{0}'.format(kernel_size))
net = self._batch_norm_fn(net, scope='bn_atr_{0}x{0}'.format(kernel_size))
elif operation in ['none']:
if stride > 1 or (input_filters != filter_size):
net = tf.nn.relu(net)
net = slim.conv2d(net, filter_size, 1, stride=stride, scope='1x1')
net = slim.batch_norm(net, scope='bn_1')
net = self._batch_norm_fn(net, scope='bn_1')
elif 'pool' in operation:
pooling_type = operation.split('_')[0]
pooling_shape = int(operation.split('_')[-1].split('x')[0])
......@@ -160,7 +182,7 @@ class NASBaseCell(object):
raise ValueError('Unimplemented pooling type: ', pooling_type)
if input_filters != filter_size:
net = slim.conv2d(net, filter_size, 1, stride=1, scope='1x1')
net = slim.batch_norm(net, scope='bn_1')
net = self._batch_norm_fn(net, scope='bn_1')
else:
raise ValueError('Unimplemented operation', operation)
......@@ -176,7 +198,7 @@ class NASBaseCell(object):
net = tf.concat(values=states_to_combine, axis=3)
return net
@tf.contrib.framework.add_arg_scope
@contrib_framework.add_arg_scope
def _apply_drop_path(self, net):
"""Apply drop_path regularization."""
drop_path_keep_prob = self._drop_path_keep_prob
......
# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -18,15 +19,17 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib import slim as contrib_slim
from deeplab.core import nas_cell
slim = contrib_slim
class PNASCell(nas_cell.NASBaseCell):
"""Configuration and construction of the PNASNet-5 Cell."""
def __init__(self, num_conv_filters, drop_path_keep_prob, total_num_cells,
total_training_steps):
total_training_steps, batch_norm_fn=slim.batch_norm):
# Name of operations: op_kernel-size_num-layers.
operations = [
'separable_5x5_2', 'max_pool_3x3', 'separable_7x7_2', 'max_pool_3x3',
......@@ -38,4 +41,5 @@ class PNASCell(nas_cell.NASBaseCell):
super(PNASCell, self).__init__(
num_conv_filters, operations, used_hiddenstates, hiddenstate_indices,
drop_path_keep_prob, total_num_cells, total_training_steps)
drop_path_keep_prob, total_num_cells, total_training_steps,
batch_norm_fn)
# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -33,21 +34,28 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import range
import tensorflow as tf
from tensorflow.contrib import framework as contrib_framework
from tensorflow.contrib import layers as contrib_layers
from tensorflow.contrib import slim as contrib_slim
from tensorflow.contrib import training as contrib_training
from deeplab.core import nas_genotypes
from deeplab.core import utils
from deeplab.core.nas_cell import NASBaseCell
from deeplab.core.utils import resize_bilinear
from deeplab.core.utils import scale_dimension
from tensorflow.contrib.slim.nets import resnet_utils
arg_scope = tf.contrib.framework.arg_scope
slim = tf.contrib.slim
arg_scope = contrib_framework.arg_scope
slim = contrib_slim
resize_bilinear = utils.resize_bilinear
scale_dimension = utils.scale_dimension
def config(num_conv_filters=20,
total_training_steps=500000,
drop_path_keep_prob=1.0):
return tf.contrib.training.HParams(
return contrib_training.HParams(
# Multiplier when spatial size is reduced by 2.
filter_scaling_rate=2.0,
# Number of filters of the stem output tensor.
......@@ -59,8 +67,10 @@ def config(num_conv_filters=20,
)
def nas_arg_scope(weight_decay=4e-5, batch_norm_decay=0.9997,
batch_norm_epsilon=0.001):
def nas_arg_scope(weight_decay=4e-5,
batch_norm_decay=0.9997,
batch_norm_epsilon=0.001,
sync_batch_norm_method='None'):
"""Default arg scope for the NAS models."""
batch_norm_params = {
# Decay for the moving averages.
......@@ -68,11 +78,11 @@ def nas_arg_scope(weight_decay=4e-5, batch_norm_decay=0.9997,
# epsilon to prevent 0s in variance.
'epsilon': batch_norm_epsilon,
'scale': True,
'fused': True,
}
weights_regularizer = tf.contrib.layers.l2_regularizer(weight_decay)
weights_initializer = tf.contrib.layers.variance_scaling_initializer(
factor=1/3.0, mode='FAN_IN', uniform=True)
batch_norm = utils.get_batch_norm_fn(sync_batch_norm_method)
weights_regularizer = contrib_layers.l2_regularizer(weight_decay)
weights_initializer = contrib_layers.variance_scaling_initializer(
factor=1 / 3.0, mode='FAN_IN', uniform=True)
with arg_scope([slim.fully_connected, slim.conv2d, slim.separable_conv2d],
weights_regularizer=weights_regularizer,
weights_initializer=weights_initializer):
......@@ -80,24 +90,22 @@ def nas_arg_scope(weight_decay=4e-5, batch_norm_decay=0.9997,
activation_fn=None, scope='FC'):
with arg_scope([slim.conv2d, slim.separable_conv2d],
activation_fn=None, biases_initializer=None):
with arg_scope([slim.batch_norm], **batch_norm_params) as sc:
with arg_scope([batch_norm], **batch_norm_params) as sc:
return sc
def _nas_stem(inputs,
              batch_norm_fn=slim.batch_norm):
  """Stem used for NAS models.

  Applies two stride-2 convolutions (overall 4x spatial reduction) with an
  intermediate stride-1 convolution, collecting intermediate activations so
  that later NAS cells can consume them as hidden states.

  NOTE(review): this block contained both the pre-refactor statements
  (slim.conv2d + slim.batch_norm) and the post-refactor ones
  (resnet_utils.conv2d_same + batch_norm_fn) side by side — diff-merge
  residue. Only the refactored version is kept here.

  Args:
    inputs: A tensor of size [batch, height, width, channels].
    batch_norm_fn: Batch norm function; defaults to slim.batch_norm so callers
      may substitute a synchronized variant.

  Returns:
    net: Stem output tensor at 1/4 of the input spatial resolution.
    cell_outputs: List of intermediate stem tensors (cell hidden states).
  """
  net = resnet_utils.conv2d_same(inputs, 64, 3, stride=2, scope='conv0')
  net = batch_norm_fn(net, scope='conv0_bn')
  net = tf.nn.relu(net)
  net = resnet_utils.conv2d_same(net, 64, 3, stride=1, scope='conv1')
  net = batch_norm_fn(net, scope='conv1_bn')
  cell_outputs = [net]
  net = tf.nn.relu(net)
  net = resnet_utils.conv2d_same(net, 128, 3, stride=2, scope='conv2')
  net = batch_norm_fn(net, scope='conv2_bn')
  cell_outputs.append(net)
  return net, cell_outputs
......@@ -108,9 +116,13 @@ def _build_nas_base(images,
num_classes,
hparams,
global_pool=False,
output_stride=16,
nas_use_classification_head=False,
reuse=None,
scope=None,
final_endpoint=None):
final_endpoint=None,
batch_norm_fn=slim.batch_norm,
nas_remove_os32_stride=False):
"""Constructs a NAS model.
Args:
......@@ -123,15 +135,22 @@ def _build_nas_base(images,
hparams: Hyperparameters needed to construct the network.
global_pool: If True, we perform global average pooling before computing the
logits. Set to True for image classification, False for dense prediction.
output_stride: Interger, the stride of output feature maps.
nas_use_classification_head: Boolean, use image classification head.
reuse: Whether or not the network and its variables should be reused. To be
able to reuse 'scope' must be given.
scope: Optional variable_scope.
final_endpoint: The endpoint to construct the network up to.
batch_norm_fn: Batch norm function.
nas_remove_os32_stride: Boolean, remove stride in output_stride 32 branch.
Returns:
net: A rank-4 tensor of size [batch, height_out, width_out, channels_out].
end_points: A dictionary from components of the network to the corresponding
activation.
Raises:
ValueError: If output_stride is not a multiple of backbone output stride.
"""
with tf.variable_scope(scope, 'nas', [images], reuse=reuse):
end_points = {}
......@@ -139,7 +158,8 @@ def _build_nas_base(images,
end_points[endpoint_name] = net
return final_endpoint and (endpoint_name == final_endpoint)
net, cell_outputs = _nas_stem(images)
net, cell_outputs = _nas_stem(images,
batch_norm_fn=batch_norm_fn)
if add_and_check_endpoint('Stem', net):
return net, end_points
......@@ -154,11 +174,18 @@ def _build_nas_base(images,
else:
if backbone[cell_num] == backbone[cell_num - 1] + 1:
stride = 2
if backbone[cell_num] == 3 and nas_remove_os32_stride:
stride = 1
filter_scaling *= hparams.filter_scaling_rate
elif backbone[cell_num] == backbone[cell_num - 1] - 1:
scaled_height = scale_dimension(net.shape[1].value, 2)
scaled_width = scale_dimension(net.shape[2].value, 2)
net = resize_bilinear(net, [scaled_height, scaled_width], net.dtype)
if backbone[cell_num - 1] == 3 and nas_remove_os32_stride:
# No need to rescale features.
pass
else:
# Scale features by a factor of 2.
scaled_height = scale_dimension(net.shape[1].value, 2)
scaled_width = scale_dimension(net.shape[2].value, 2)
net = resize_bilinear(net, [scaled_height, scaled_width], net.dtype)
filter_scaling /= hparams.filter_scaling_rate
net = cell(
net,
......@@ -172,11 +199,48 @@ def _build_nas_base(images,
cell_outputs.append(net)
net = tf.nn.relu(net)
if nas_use_classification_head:
# Add image classification head.
# We will expand the filters for different output_strides.
output_stride_to_expanded_filters = {8: 256, 16: 512, 32: 1024}
current_output_scale = 2 + backbone[-1]
current_output_stride = 2 ** current_output_scale
if output_stride % current_output_stride != 0:
raise ValueError(
'output_stride must be a multiple of backbone output stride.')
output_stride //= current_output_stride
rate = 1
if current_output_stride != 32:
num_downsampling = 5 - current_output_scale
for i in range(num_downsampling):
# Gradually donwsample feature maps to output stride = 32.
target_output_stride = 2 ** (current_output_scale + 1 + i)
target_filters = output_stride_to_expanded_filters[
target_output_stride]
scope = 'downsample_os{}'.format(target_output_stride)
if output_stride != 1:
stride = 2
output_stride //= 2
else:
stride = 1
rate *= 2
net = resnet_utils.conv2d_same(
net, target_filters, 3, stride=stride, rate=rate,
scope=scope + '_conv')
net = batch_norm_fn(net, scope=scope + '_bn')
add_and_check_endpoint(scope, net)
net = tf.nn.relu(net)
# Apply 1x1 convolution to expand dimension to 2048.
scope = 'classification_head'
net = slim.conv2d(net, 2048, 1, scope=scope + '_conv')
net = batch_norm_fn(net, scope=scope + '_bn')
add_and_check_endpoint(scope, net)
net = tf.nn.relu(net)
if global_pool:
# Global average pooling.
net = tf.reduce_mean(net, [1, 2], name='global_pool', keepdims=True)
if num_classes is not None:
net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
net = slim.conv2d(net, num_classes, 1, activation_fn=None,
normalizer_fn=None, scope='logits')
end_points['predictions'] = slim.softmax(net, scope='predictions')
return net, end_points
......@@ -187,13 +251,18 @@ def pnasnet(images,
is_training=True,
global_pool=False,
output_stride=16,
nas_stem_output_num_conv_filters=20,
nas_architecture_options=None,
nas_training_hyper_parameters=None,
reuse=None,
scope='pnasnet',
final_endpoint=None):
final_endpoint=None,
sync_batch_norm_method='None'):
"""Builds PNASNet model."""
hparams = config(num_conv_filters=nas_stem_output_num_conv_filters)
if nas_architecture_options is None:
raise ValueError(
'Using NAS model variants. nas_architecture_options cannot be None.')
hparams = config(num_conv_filters=nas_architecture_options[
'nas_stem_output_num_conv_filters'])
if nas_training_hyper_parameters:
hparams.set_hparam('drop_path_keep_prob',
nas_training_hyper_parameters['drop_path_keep_prob'])
......@@ -211,11 +280,13 @@ def pnasnet(images,
backbone = [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]
else:
raise ValueError('Unsupported output_stride ', output_stride)
batch_norm = utils.get_batch_norm_fn(sync_batch_norm_method)
cell = nas_genotypes.PNASCell(hparams.num_conv_filters,
hparams.drop_path_keep_prob,
len(backbone),
hparams.total_training_steps)
with arg_scope([slim.dropout, slim.batch_norm], is_training=is_training):
hparams.total_training_steps,
batch_norm_fn=batch_norm)
with arg_scope([slim.dropout, batch_norm], is_training=is_training):
return _build_nas_base(
images,
cell=cell,
......@@ -223,9 +294,15 @@ def pnasnet(images,
num_classes=num_classes,
hparams=hparams,
global_pool=global_pool,
output_stride=output_stride,
nas_use_classification_head=nas_architecture_options[
'nas_use_classification_head'],
reuse=reuse,
scope=scope,
final_endpoint=final_endpoint)
final_endpoint=final_endpoint,
batch_norm_fn=batch_norm,
nas_remove_os32_stride=nas_architecture_options[
'nas_remove_os32_stride'])
# pylint: disable=unused-argument
......@@ -233,14 +310,19 @@ def hnasnet(images,
num_classes,
is_training=True,
global_pool=False,
output_stride=16,
nas_stem_output_num_conv_filters=20,
output_stride=8,
nas_architecture_options=None,
nas_training_hyper_parameters=None,
reuse=None,
scope='hnasnet',
final_endpoint=None):
final_endpoint=None,
sync_batch_norm_method='None'):
"""Builds hierarchical model."""
hparams = config(num_conv_filters=nas_stem_output_num_conv_filters)
if nas_architecture_options is None:
raise ValueError(
'Using NAS model variants. nas_architecture_options cannot be None.')
hparams = config(num_conv_filters=nas_architecture_options[
'nas_stem_output_num_conv_filters'])
if nas_training_hyper_parameters:
hparams.set_hparam('drop_path_keep_prob',
nas_training_hyper_parameters['drop_path_keep_prob'])
......@@ -258,14 +340,16 @@ def hnasnet(images,
used_hiddenstates = [1, 1, 0, 0, 0, 0, 0]
hiddenstate_indices = [1, 0, 1, 0, 3, 1, 4, 2, 3, 5]
backbone = [0, 0, 0, 1, 2, 1, 2, 2, 3, 3, 2, 1]
batch_norm = utils.get_batch_norm_fn(sync_batch_norm_method)
cell = NASBaseCell(hparams.num_conv_filters,
operations,
used_hiddenstates,
hiddenstate_indices,
hparams.drop_path_keep_prob,
len(backbone),
hparams.total_training_steps)
with arg_scope([slim.dropout, slim.batch_norm], is_training=is_training):
hparams.total_training_steps,
batch_norm_fn=batch_norm)
with arg_scope([slim.dropout, batch_norm], is_training=is_training):
return _build_nas_base(
images,
cell=cell,
......@@ -273,6 +357,12 @@ def hnasnet(images,
num_classes=num_classes,
hparams=hparams,
global_pool=global_pool,
output_stride=output_stride,
nas_use_classification_head=nas_architecture_options[
'nas_use_classification_head'],
reuse=reuse,
scope=scope,
final_endpoint=final_endpoint)
final_endpoint=final_endpoint,
batch_norm_fn=batch_norm,
nas_remove_os32_stride=nas_architecture_options[
'nas_remove_os32_stride'])
# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -21,12 +22,15 @@ from __future__ import print_function
import numpy as np
import tensorflow as tf
from tensorflow.contrib import framework as contrib_framework
from tensorflow.contrib import slim as contrib_slim
from tensorflow.contrib import training as contrib_training
from deeplab.core import nas_genotypes
from deeplab.core import nas_network
arg_scope = tf.contrib.framework.arg_scope
slim = tf.contrib.slim
arg_scope = contrib_framework.arg_scope
slim = contrib_slim
def create_test_input(batch, height, width, channels):
......@@ -54,7 +58,7 @@ class NASNetworkTest(tf.test.TestCase):
output_stride=16,
final_endpoint=None):
"""Build PNASNet model backbone."""
hparams = tf.contrib.training.HParams(
hparams = contrib_training.HParams(
filter_scaling_rate=2.0,
num_conv_filters=10,
drop_path_keep_prob=1.0,
......
This diff is collapsed.
# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -22,12 +23,14 @@ from __future__ import print_function
import functools
import numpy as np
import six
import tensorflow as tf
from tensorflow.contrib import slim as contrib_slim
from deeplab.core import resnet_v1_beta
from tensorflow.contrib.slim.nets import resnet_utils
slim = tf.contrib.slim
slim = contrib_slim
def create_test_input(batch, height, width, channels):
......@@ -47,6 +50,43 @@ def create_test_input(batch, height, width, channels):
class ResnetCompleteNetworkTest(tf.test.TestCase):
"""Tests with complete small ResNet v1 networks."""
def _resnet_small_lite_bottleneck(self,
                                  inputs,
                                  num_classes=None,
                                  is_training=True,
                                  global_pool=True,
                                  output_stride=None,
                                  multi_grid=None,
                                  reuse=None,
                                  scope='resnet_v1_small'):
  """A shallow and thin ResNet v1 with lite_bottleneck."""
  if multi_grid is None:
    multi_grid = [1, 1]
  elif len(multi_grid) != 2:
    raise ValueError('Expect multi_grid to have length 2.')
  # Three stride-2 lite-bottleneck stages, then a dense block4 whose units
  # apply the multi_grid atrous rates.
  small_block = resnet_v1_beta.resnet_v1_small_beta_block
  final_units = [
      {'depth': 8, 'stride': 1, 'unit_rate': rate} for rate in multi_grid
  ]
  blocks = [
      small_block('block1', base_depth=1, num_units=1, stride=2),
      small_block('block2', base_depth=2, num_units=1, stride=2),
      small_block('block3', base_depth=4, num_units=1, stride=2),
      resnet_utils.Block('block4', resnet_v1_beta.lite_bottleneck,
                         final_units),
  ]
  # Beta-variant root block with a small depth multiplier keeps this test
  # network thin.
  root_fn = functools.partial(
      resnet_v1_beta.root_block_fn_for_beta_variant, depth_multiplier=0.25)
  return resnet_v1_beta.resnet_v1_beta(
      inputs,
      blocks,
      num_classes=num_classes,
      is_training=is_training,
      global_pool=global_pool,
      output_stride=output_stride,
      root_block_fn=root_fn,
      reuse=reuse,
      scope=scope)
def _resnet_small(self,
inputs,
num_classes=None,
......@@ -65,13 +105,11 @@ class ResnetCompleteNetworkTest(tf.test.TestCase):
block = resnet_v1_beta.resnet_v1_beta_block
blocks = [
block('block1', base_depth=1, num_units=3, stride=2),
block('block2', base_depth=2, num_units=3, stride=2),
block('block3', base_depth=4, num_units=3, stride=2),
block('block1', base_depth=1, num_units=1, stride=2),
block('block2', base_depth=2, num_units=1, stride=2),
block('block3', base_depth=4, num_units=1, stride=2),
resnet_utils.Block('block4', resnet_v1_beta.bottleneck, [
{'depth': 32,
'depth_bottleneck': 8,
'stride': 1,
{'depth': 32, 'depth_bottleneck': 8, 'stride': 1,
'unit_rate': rate} for rate in multi_grid])]
return resnet_v1_beta.resnet_v1_beta(
......@@ -86,6 +124,199 @@ class ResnetCompleteNetworkTest(tf.test.TestCase):
reuse=reuse,
scope=scope)
def testClassificationEndPointsWithLiteBottleneck(self):
  """Checks logits/predictions endpoint shapes for the lite-bottleneck net."""
  global_pool = True
  num_classes = 10
  inputs = create_test_input(2, 224, 224, 3)
  with slim.arg_scope(resnet_utils.resnet_arg_scope()):
    logits, end_points = self._resnet_small_lite_bottleneck(
        inputs,
        num_classes,
        global_pool=global_pool,
        scope='resnet')
  # With global pooling the spatial dims collapse to 1x1.
  self.assertTrue(logits.op.name.startswith('resnet/logits'))
  self.assertListEqual(logits.get_shape().as_list(), [2, 1, 1, num_classes])
  self.assertIn('predictions', end_points)
  self.assertListEqual(end_points['predictions'].get_shape().as_list(),
                       [2, 1, 1, num_classes])
def testClassificationEndPointsWithMultigridAndLiteBottleneck(self):
  """Multi-grid atrous rates should not change classification head shapes."""
  global_pool = True
  num_classes = 10
  inputs = create_test_input(2, 224, 224, 3)
  multi_grid = [1, 2]
  with slim.arg_scope(resnet_utils.resnet_arg_scope()):
    logits, end_points = self._resnet_small_lite_bottleneck(
        inputs,
        num_classes,
        global_pool=global_pool,
        multi_grid=multi_grid,
        scope='resnet')
  self.assertTrue(logits.op.name.startswith('resnet/logits'))
  self.assertListEqual(logits.get_shape().as_list(), [2, 1, 1, num_classes])
  self.assertIn('predictions', end_points)
  self.assertListEqual(end_points['predictions'].get_shape().as_list(),
                       [2, 1, 1, num_classes])
def testClassificationShapesWithLiteBottleneck(self):
  """Checks per-block feature shapes for 224x224 input (nominal strides)."""
  global_pool = True
  num_classes = 10
  inputs = create_test_input(2, 224, 224, 3)
  with slim.arg_scope(resnet_utils.resnet_arg_scope()):
    _, end_points = self._resnet_small_lite_bottleneck(
        inputs,
        num_classes,
        global_pool=global_pool,
        scope='resnet')
    # Expected shapes follow the stride-2 stages; block4 keeps stride 1.
    endpoint_to_shape = {
        'resnet/conv1_1': [2, 112, 112, 16],
        'resnet/conv1_2': [2, 112, 112, 16],
        'resnet/conv1_3': [2, 112, 112, 32],
        'resnet/block1': [2, 28, 28, 1],
        'resnet/block2': [2, 14, 14, 2],
        'resnet/block3': [2, 7, 7, 4],
        'resnet/block4': [2, 7, 7, 8]}
    for endpoint, shape in six.iteritems(endpoint_to_shape):
      self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
def testFullyConvolutionalEndpointShapesWithLiteBottleneck(self):
  """Checks endpoint shapes in fully convolutional mode (321x321 input)."""
  global_pool = False
  num_classes = 10
  inputs = create_test_input(2, 321, 321, 3)
  with slim.arg_scope(resnet_utils.resnet_arg_scope()):
    _, end_points = self._resnet_small_lite_bottleneck(
        inputs,
        num_classes,
        global_pool=global_pool,
        scope='resnet')
    # 321 = 32 * 10 + 1, so sizes stay odd through the stride-2 stages.
    endpoint_to_shape = {
        'resnet/conv1_1': [2, 161, 161, 16],
        'resnet/conv1_2': [2, 161, 161, 16],
        'resnet/conv1_3': [2, 161, 161, 32],
        'resnet/block1': [2, 41, 41, 1],
        'resnet/block2': [2, 21, 21, 2],
        'resnet/block3': [2, 11, 11, 4],
        'resnet/block4': [2, 11, 11, 8]}
    for endpoint, shape in six.iteritems(endpoint_to_shape):
      self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
def testAtrousFullyConvolutionalEndpointShapesWithLiteBottleneck(self):
  """Checks that output_stride=8 keeps block2-4 at the block1 resolution."""
  global_pool = False
  num_classes = 10
  output_stride = 8
  inputs = create_test_input(2, 321, 321, 3)
  with slim.arg_scope(resnet_utils.resnet_arg_scope()):
    _, end_points = self._resnet_small_lite_bottleneck(
        inputs,
        num_classes,
        global_pool=global_pool,
        output_stride=output_stride,
        scope='resnet')
    # Atrous convolution replaces striding beyond output_stride, so all
    # later blocks stay at 41x41.
    endpoint_to_shape = {
        'resnet/conv1_1': [2, 161, 161, 16],
        'resnet/conv1_2': [2, 161, 161, 16],
        'resnet/conv1_3': [2, 161, 161, 32],
        'resnet/block1': [2, 41, 41, 1],
        'resnet/block2': [2, 41, 41, 2],
        'resnet/block3': [2, 41, 41, 4],
        'resnet/block4': [2, 41, 41, 8]}
    for endpoint, shape in six.iteritems(endpoint_to_shape):
      self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
def testAtrousFullyConvolutionalValuesWithLiteBottleneck(self):
  """Verify dense feature extraction with atrous convolution.

  For each output_stride, dense extraction followed by subsampling must
  numerically match extraction at the nominal stride with shared weights.
  """
  nominal_stride = 32
  for output_stride in [4, 8, 16, 32, None]:
    with slim.arg_scope(resnet_utils.resnet_arg_scope()):
      with tf.Graph().as_default():
        with self.test_session() as sess:
          tf.set_random_seed(0)
          inputs = create_test_input(2, 81, 81, 3)
          # Dense feature extraction followed by subsampling.
          output, _ = self._resnet_small_lite_bottleneck(
              inputs,
              None,
              is_training=False,
              global_pool=False,
              output_stride=output_stride)
          if output_stride is None:
            factor = 1
          else:
            factor = nominal_stride // output_stride
          output = resnet_utils.subsample(output, factor)
          # Make the two networks use the same weights.
          tf.get_variable_scope().reuse_variables()
          # Feature extraction at the nominal network rate.
          expected, _ = self._resnet_small_lite_bottleneck(
              inputs,
              None,
              is_training=False,
              global_pool=False)
          sess.run(tf.global_variables_initializer())
          self.assertAllClose(output.eval(), expected.eval(),
                              atol=1e-4, rtol=1e-4)
def testUnknownBatchSizeWithLiteBottleneck(self):
  """Static batch dim may be None; runtime batch is recovered at sess.run."""
  batch = 2
  height, width = 65, 65
  global_pool = True
  num_classes = 10
  inputs = create_test_input(None, height, width, 3)
  with slim.arg_scope(resnet_utils.resnet_arg_scope()):
    logits, _ = self._resnet_small_lite_bottleneck(
        inputs,
        num_classes,
        global_pool=global_pool,
        scope='resnet')
  self.assertTrue(logits.op.name.startswith('resnet/logits'))
  self.assertListEqual(logits.get_shape().as_list(),
                       [None, 1, 1, num_classes])
  images = create_test_input(batch, height, width, 3)
  with self.test_session() as sess:
    sess.run(tf.global_variables_initializer())
    output = sess.run(logits, {inputs: images.eval()})
    self.assertEqual(output.shape, (batch, 1, 1, num_classes))
def testFullyConvolutionalUnknownHeightWidthWithLiteBottleneck(self):
  """Unknown spatial dims: static shape is [batch, None, None, 8]."""
  batch = 2
  height, width = 65, 65
  global_pool = False
  inputs = create_test_input(batch, None, None, 3)
  with slim.arg_scope(resnet_utils.resnet_arg_scope()):
    output, _ = self._resnet_small_lite_bottleneck(
        inputs,
        None,
        global_pool=global_pool)
  self.assertListEqual(output.get_shape().as_list(),
                       [batch, None, None, 8])
  images = create_test_input(batch, height, width, 3)
  with self.test_session() as sess:
    sess.run(tf.global_variables_initializer())
    # 65x65 input at nominal stride 32 yields 3x3 features.
    output = sess.run(output, {inputs: images.eval()})
    self.assertEqual(output.shape, (batch, 3, 3, 8))
def testAtrousFullyConvolutionalUnknownHeightWidthWithLiteBottleneck(self):
  """Unknown spatial dims with output_stride=8: 65x65 input yields 9x9."""
  batch = 2
  height, width = 65, 65
  global_pool = False
  output_stride = 8
  inputs = create_test_input(batch, None, None, 3)
  with slim.arg_scope(resnet_utils.resnet_arg_scope()):
    output, _ = self._resnet_small_lite_bottleneck(
        inputs,
        None,
        global_pool=global_pool,
        output_stride=output_stride)
  self.assertListEqual(output.get_shape().as_list(),
                       [batch, None, None, 8])
  images = create_test_input(batch, height, width, 3)
  with self.test_session() as sess:
    sess.run(tf.global_variables_initializer())
    output = sess.run(output, {inputs: images.eval()})
    self.assertEqual(output.shape, (batch, 9, 9, 8))
def testClassificationEndPoints(self):
global_pool = True
num_classes = 10
......@@ -98,7 +329,66 @@ class ResnetCompleteNetworkTest(tf.test.TestCase):
self.assertTrue(logits.op.name.startswith('resnet/logits'))
self.assertListEqual(logits.get_shape().as_list(), [2, 1, 1, num_classes])
self.assertTrue('predictions' in end_points)
self.assertIn('predictions', end_points)
self.assertListEqual(end_points['predictions'].get_shape().as_list(),
[2, 1, 1, num_classes])
def testClassificationEndPointsWithWS(self):
  """Endpoint shapes are unchanged when weight standardization is enabled."""
  global_pool = True
  num_classes = 10
  inputs = create_test_input(2, 224, 224, 3)
  with slim.arg_scope(
      resnet_v1_beta.resnet_arg_scope(use_weight_standardization=True)):
    logits, end_points = self._resnet_small(
        inputs, num_classes, global_pool=global_pool, scope='resnet')
  self.assertTrue(logits.op.name.startswith('resnet/logits'))
  self.assertListEqual(logits.get_shape().as_list(), [2, 1, 1, num_classes])
  self.assertIn('predictions', end_points)
  self.assertListEqual(end_points['predictions'].get_shape().as_list(),
                       [2, 1, 1, num_classes])
def testClassificationEndPointsWithGN(self):
  """Endpoint shapes are unchanged when group normalization is used."""
  global_pool = True
  num_classes = 10
  inputs = create_test_input(2, 224, 224, 3)
  with slim.arg_scope(
      resnet_v1_beta.resnet_arg_scope(normalization_method='group')):
    # groups=1 so the thin test network's channel counts stay divisible.
    with slim.arg_scope([slim.group_norm], groups=1):
      logits, end_points = self._resnet_small(
          inputs, num_classes, global_pool=global_pool, scope='resnet')
  self.assertTrue(logits.op.name.startswith('resnet/logits'))
  self.assertListEqual(logits.get_shape().as_list(), [2, 1, 1, num_classes])
  self.assertIn('predictions', end_points)
  self.assertListEqual(end_points['predictions'].get_shape().as_list(),
                       [2, 1, 1, num_classes])
def testInvalidGroupsWithGN(self):
  """groups=32 exceeds the thin network's channel count and must raise."""
  global_pool = True
  num_classes = 10
  inputs = create_test_input(2, 224, 224, 3)
  with self.assertRaisesRegexp(ValueError, 'Invalid groups'):
    with slim.arg_scope(
        resnet_v1_beta.resnet_arg_scope(normalization_method='group')):
      with slim.arg_scope([slim.group_norm], groups=32):
        _, _ = self._resnet_small(
            inputs, num_classes, global_pool=global_pool, scope='resnet')
def testClassificationEndPointsWithGNWS(self):
  """Group norm combined with weight standardization keeps shapes intact."""
  global_pool = True
  num_classes = 10
  inputs = create_test_input(2, 224, 224, 3)
  with slim.arg_scope(
      resnet_v1_beta.resnet_arg_scope(
          normalization_method='group', use_weight_standardization=True)):
    with slim.arg_scope([slim.group_norm], groups=1):
      logits, end_points = self._resnet_small(
          inputs, num_classes, global_pool=global_pool, scope='resnet')
  self.assertTrue(logits.op.name.startswith('resnet/logits'))
  self.assertListEqual(logits.get_shape().as_list(), [2, 1, 1, num_classes])
  self.assertIn('predictions', end_points)
  self.assertListEqual(end_points['predictions'].get_shape().as_list(),
                       [2, 1, 1, num_classes])
......@@ -116,7 +406,7 @@ class ResnetCompleteNetworkTest(tf.test.TestCase):
self.assertTrue(logits.op.name.startswith('resnet/logits'))
self.assertListEqual(logits.get_shape().as_list(), [2, 1, 1, num_classes])
self.assertTrue('predictions' in end_points)
self.assertIn('predictions', end_points)
self.assertListEqual(end_points['predictions'].get_shape().as_list(),
[2, 1, 1, num_classes])
......@@ -137,7 +427,7 @@ class ResnetCompleteNetworkTest(tf.test.TestCase):
'resnet/block2': [2, 14, 14, 8],
'resnet/block3': [2, 7, 7, 16],
'resnet/block4': [2, 7, 7, 32]}
for endpoint, shape in endpoint_to_shape.iteritems():
for endpoint, shape in six.iteritems(endpoint_to_shape):
self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
def testFullyConvolutionalEndpointShapes(self):
......@@ -157,7 +447,7 @@ class ResnetCompleteNetworkTest(tf.test.TestCase):
'resnet/block2': [2, 21, 21, 8],
'resnet/block3': [2, 11, 11, 16],
'resnet/block4': [2, 11, 11, 32]}
for endpoint, shape in endpoint_to_shape.iteritems():
for endpoint, shape in six.iteritems(endpoint_to_shape):
self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
def testAtrousFullyConvolutionalEndpointShapes(self):
......@@ -179,7 +469,7 @@ class ResnetCompleteNetworkTest(tf.test.TestCase):
'resnet/block2': [2, 41, 41, 8],
'resnet/block3': [2, 41, 41, 16],
'resnet/block4': [2, 41, 41, 32]}
for endpoint, shape in endpoint_to_shape.iteritems():
for endpoint, shape in six.iteritems(endpoint_to_shape):
self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
def testAtrousFullyConvolutionalValues(self):
......@@ -231,7 +521,7 @@ class ResnetCompleteNetworkTest(tf.test.TestCase):
with self.test_session() as sess:
sess.run(tf.global_variables_initializer())
output = sess.run(logits, {inputs: images.eval()})
self.assertEquals(output.shape, (batch, 1, 1, num_classes))
self.assertEqual(output.shape, (batch, 1, 1, num_classes))
def testFullyConvolutionalUnknownHeightWidth(self):
batch = 2
......@@ -248,7 +538,7 @@ class ResnetCompleteNetworkTest(tf.test.TestCase):
with self.test_session() as sess:
sess.run(tf.global_variables_initializer())
output = sess.run(output, {inputs: images.eval()})
self.assertEquals(output.shape, (batch, 3, 3, 32))
self.assertEqual(output.shape, (batch, 3, 3, 32))
def testAtrousFullyConvolutionalUnknownHeightWidth(self):
batch = 2
......@@ -267,7 +557,7 @@ class ResnetCompleteNetworkTest(tf.test.TestCase):
with self.test_session() as sess:
sess.run(tf.global_variables_initializer())
output = sess.run(output, {inputs: images.eval()})
self.assertEquals(output.shape, (batch, 9, 9, 32))
self.assertEqual(output.shape, (batch, 9, 9, 32))
if __name__ == '__main__':
......
# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -15,8 +16,14 @@
"""This script contains utility functions."""
import tensorflow as tf
from tensorflow.contrib import framework as contrib_framework
from tensorflow.contrib import slim as contrib_slim
slim = tf.contrib.slim
slim = contrib_slim
# Quantized version of sigmoid function.
q_sigmoid = lambda x: tf.nn.relu6(x + 3) * 0.16667
def resize_bilinear(images, size, output_dtype=tf.float32):
......@@ -98,3 +105,110 @@ def split_separable_conv2d(inputs,
stddev=pointwise_weights_initializer_stddev),
weights_regularizer=slim.l2_regularizer(weight_decay),
scope=scope + '_pointwise')
def get_label_weight_mask(labels, ignore_label, num_classes, label_weights=1.0):
  """Gets the label weight mask.

  Args:
    labels: A Tensor of labels with the shape of [-1].
    ignore_label: Integer, label to ignore.
    num_classes: Integer, the number of semantic classes.
    label_weights: A float or a list of weights. A single float applies the
      same weight to every label; a list gives a per-class weight indexed by
      label value (e.g. label_weights = [0.1, 0.5] weights label 0 by 0.1 and
      label 1 by 0.5).

  Returns:
    A Tensor with the same shape as labels, holding the weight for each label
    and 0.0 wherever the label equals ignore_label.

  Raises:
    ValueError: If label_weights is neither a float nor a list, or if
      label_weights is a list whose length differs from num_classes.
  """
  weights_is_float = isinstance(label_weights, float)
  weights_is_list = isinstance(label_weights, list)
  if not (weights_is_float or weights_is_list):
    raise ValueError(
        'The type of label_weights is invalid, it must be a float or a list.')
  if weights_is_list and len(label_weights) != num_classes:
    raise ValueError(
        'Length of label_weights must be equal to num_classes if it is a list, '
        'label_weights: %s, num_classes: %d.' % (label_weights, num_classes))
  valid_mask = tf.cast(tf.not_equal(labels, ignore_label), tf.float32)
  if weights_is_float:
    # Uniform weighting: every non-ignored label gets the same scalar.
    return valid_mask * label_weights
  # Per-class weighting: project each label onto its class weight via one-hot.
  # Labels outside [0, num_classes) one-hot to all zeros, so out-of-range
  # ignore labels contribute weight 0 even before masking.
  per_class_weights = tf.constant(label_weights, tf.float32)
  weight_mask = tf.einsum('...y,y->...',
                          tf.one_hot(labels, num_classes, dtype=tf.float32),
                          per_class_weights)
  return tf.multiply(valid_mask, weight_mask)
def get_batch_norm_fn(sync_batch_norm_method):
  """Gets batch norm function.

  Currently we only support the following methods:
    - `None` (no sync batch norm). We use slim.batch_norm in this case.

  Args:
    sync_batch_norm_method: String, method used to sync batch norm.

  Returns:
    Batchnorm function.

  Raises:
    ValueError: If sync_batch_norm_method is not supported.
  """
  # Guard clause: anything other than the literal string 'None' is
  # unsupported today.
  if sync_batch_norm_method != 'None':
    raise ValueError('Unsupported sync_batch_norm_method.')
  return slim.batch_norm
def get_batch_norm_params(decay=0.9997,
                          epsilon=1e-5,
                          center=True,
                          scale=True,
                          is_training=True,
                          sync_batch_norm_method='None',
                          initialize_gamma_as_zeros=False):
  """Gets batch norm parameters.

  Args:
    decay: Float, decay for the moving average.
    epsilon: Float, value added to variance to avoid dividing by zero.
    center: Boolean. If True, add offset of `beta` to normalized tensor. If
      False, `beta` is ignored.
    scale: Boolean. If True, multiply by `gamma`. If False, `gamma` is not
      used.
    is_training: Boolean, whether or not the layer is in training mode.
    sync_batch_norm_method: String, method used to sync batch norm.
    initialize_gamma_as_zeros: Boolean, initializing `gamma` as zeros or not.

  Returns:
    A dictionary for batchnorm parameters.

  Raises:
    ValueError: If sync_batch_norm_method is not supported.
  """
  params = {
      'is_training': is_training,
      'decay': decay,
      'epsilon': epsilon,
      'scale': scale,
      'center': center,
  }
  if not initialize_gamma_as_zeros:
    return params
  if sync_batch_norm_method != 'None':
    raise ValueError('Unsupported sync_batch_norm_method.')
  # Slim-style gamma initializer (zero-init gamma, e.g. for residual blocks).
  params['param_initializers'] = {
      'gamma': tf.zeros_initializer(),
  }
  return params
# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -14,6 +15,7 @@
# ==============================================================================
"""Tests for utils.py."""
import numpy as np
import tensorflow as tf
from deeplab.core import utils
......@@ -26,6 +28,63 @@ class UtilsTest(tf.test.TestCase):
self.assertEqual(193, utils.scale_dimension(321, 0.6))
self.assertEqual(241, utils.scale_dimension(321, 0.75))
def testGetLabelWeightMask_withFloatLabelWeights(self):
  """A scalar weight applies uniformly; the ignore label (4) gets 0.0."""
  labels = tf.constant([0, 4, 1, 3, 2])
  ignore_label = 4
  num_classes = 5
  label_weights = 0.5
  expected_label_weight_mask = np.array([0.5, 0.0, 0.5, 0.5, 0.5],
                                        dtype=np.float32)
  with self.test_session() as sess:
    label_weight_mask = utils.get_label_weight_mask(
        labels, ignore_label, num_classes, label_weights=label_weights)
    label_weight_mask = sess.run(label_weight_mask)
    self.assertAllEqual(label_weight_mask, expected_label_weight_mask)
def testGetLabelWeightMask_withListLabelWeights(self):
  """A weight list is indexed per label; the ignore label (4) gets 0.0."""
  labels = tf.constant([0, 4, 1, 3, 2])
  ignore_label = 4
  num_classes = 5
  label_weights = [0.0, 0.1, 0.2, 0.3, 0.4]
  # Each entry is label_weights[label], except position 1 which is ignored.
  expected_label_weight_mask = np.array([0.0, 0.0, 0.1, 0.3, 0.2],
                                        dtype=np.float32)
  with self.test_session() as sess:
    label_weight_mask = utils.get_label_weight_mask(
        labels, ignore_label, num_classes, label_weights=label_weights)
    label_weight_mask = sess.run(label_weight_mask)
    self.assertAllEqual(label_weight_mask, expected_label_weight_mask)
def testGetLabelWeightMask_withInvalidLabelWeightsType(self):
  """label_weights=None is neither float nor list and must raise ValueError."""
  labels = tf.constant([0, 4, 1, 3, 2])
  ignore_label = 4
  num_classes = 5
  self.assertRaisesWithRegexpMatch(
      ValueError,
      '^The type of label_weights is invalid, it must be a float or a list',
      utils.get_label_weight_mask,
      labels=labels,
      ignore_label=ignore_label,
      num_classes=num_classes,
      label_weights=None)
def testGetLabelWeightMask_withInvalidLabelWeightsLength(self):
  """A weight list shorter than num_classes must raise ValueError."""
  labels = tf.constant([0, 4, 1, 3, 2])
  ignore_label = 4
  num_classes = 5
  label_weights = [0.0, 0.1, 0.2]
  self.assertRaisesWithRegexpMatch(
      ValueError,
      '^Length of label_weights must be equal to num_classes if it is a list',
      utils.get_label_weight_mask,
      labels=labels,
      ignore_label=ignore_label,
      num_classes=num_classes,
      label_weights=label_weights)
if __name__ == '__main__':
tf.test.main()
This diff is collapsed.
# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -17,11 +18,12 @@
import numpy as np
import six
import tensorflow as tf
from tensorflow.contrib import slim as contrib_slim
from deeplab.core import xception
from tensorflow.contrib.slim.nets import resnet_utils
slim = tf.contrib.slim
slim = contrib_slim
def create_test_input(batch, height, width, channels):
......@@ -29,13 +31,14 @@ def create_test_input(batch, height, width, channels):
if None in [batch, height, width, channels]:
return tf.placeholder(tf.float32, (batch, height, width, channels))
else:
return tf.to_float(
return tf.cast(
np.tile(
np.reshape(
np.reshape(np.arange(height), [height, 1]) +
np.reshape(np.arange(width), [1, width]),
[1, height, width, 1]),
[batch, 1, 1, channels]))
[batch, 1, 1, channels]),
tf.float32)
class UtilityFunctionTest(tf.test.TestCase):
......@@ -58,15 +61,15 @@ class UtilityFunctionTest(tf.test.TestCase):
y1 = slim.separable_conv2d(x, 1, [3, 3], depth_multiplier=1,
stride=1, scope='Conv')
y1_expected = tf.to_float([[14, 28, 43, 26],
[28, 48, 66, 37],
[43, 66, 84, 46],
[26, 37, 46, 22]])
y1_expected = tf.cast([[14, 28, 43, 26],
[28, 48, 66, 37],
[43, 66, 84, 46],
[26, 37, 46, 22]], tf.float32)
y1_expected = tf.reshape(y1_expected, [1, n, n, 1])
y2 = resnet_utils.subsample(y1, 2)
y2_expected = tf.to_float([[14, 43],
[43, 84]])
y2_expected = tf.cast([[14, 43],
[43, 84]], tf.float32)
y2_expected = tf.reshape(y2_expected, [1, n2, n2, 1])
y3 = xception.separable_conv2d_same(x, 1, 3, depth_multiplier=1,
......@@ -76,8 +79,8 @@ class UtilityFunctionTest(tf.test.TestCase):
y4 = slim.separable_conv2d(x, 1, [3, 3], depth_multiplier=1,
stride=2, scope='Conv')
y4_expected = tf.to_float([[48, 37],
[37, 22]])
y4_expected = tf.cast([[48, 37],
[37, 22]], tf.float32)
y4_expected = tf.reshape(y4_expected, [1, n2, n2, 1])
with self.test_session() as sess:
......@@ -105,17 +108,17 @@ class UtilityFunctionTest(tf.test.TestCase):
y1 = slim.separable_conv2d(x, 1, [3, 3], depth_multiplier=1,
stride=1, scope='Conv')
y1_expected = tf.to_float([[14, 28, 43, 58, 34],
[28, 48, 66, 84, 46],
[43, 66, 84, 102, 55],
[58, 84, 102, 120, 64],
[34, 46, 55, 64, 30]])
y1_expected = tf.cast([[14, 28, 43, 58, 34],
[28, 48, 66, 84, 46],
[43, 66, 84, 102, 55],
[58, 84, 102, 120, 64],
[34, 46, 55, 64, 30]], tf.float32)
y1_expected = tf.reshape(y1_expected, [1, n, n, 1])
y2 = resnet_utils.subsample(y1, 2)
y2_expected = tf.to_float([[14, 43, 34],
[43, 84, 55],
[34, 55, 30]])
y2_expected = tf.cast([[14, 43, 34],
[43, 84, 55],
[34, 55, 30]], tf.float32)
y2_expected = tf.reshape(y2_expected, [1, n2, n2, 1])
y3 = xception.separable_conv2d_same(x, 1, 3, depth_multiplier=1,
......@@ -211,8 +214,8 @@ class XceptionNetworkTest(tf.test.TestCase):
def testClassificationEndPoints(self):
global_pool = True
num_classes = 10
inputs = create_test_input(2, 224, 224, 3)
num_classes = 3
inputs = create_test_input(2, 32, 32, 3)
with slim.arg_scope(xception.xception_arg_scope()):
logits, end_points = self._xception_small(
inputs,
......@@ -231,8 +234,8 @@ class XceptionNetworkTest(tf.test.TestCase):
def testEndpointNames(self):
global_pool = True
num_classes = 10
inputs = create_test_input(2, 224, 224, 3)
num_classes = 3
inputs = create_test_input(2, 32, 32, 3)
with slim.arg_scope(xception.xception_arg_scope()):
_, end_points = self._xception_small(
inputs,
......@@ -290,12 +293,12 @@ class XceptionNetworkTest(tf.test.TestCase):
'xception/logits',
'predictions',
]
self.assertItemsEqual(end_points.keys(), expected)
self.assertItemsEqual(list(end_points.keys()), expected)
def testClassificationShapes(self):
global_pool = True
num_classes = 10
inputs = create_test_input(2, 224, 224, 3)
num_classes = 3
inputs = create_test_input(2, 64, 64, 3)
with slim.arg_scope(xception.xception_arg_scope()):
_, end_points = self._xception_small(
inputs,
......@@ -303,20 +306,20 @@ class XceptionNetworkTest(tf.test.TestCase):
global_pool=global_pool,
scope='xception')
endpoint_to_shape = {
'xception/entry_flow/conv1_1': [2, 112, 112, 32],
'xception/entry_flow/block1': [2, 56, 56, 1],
'xception/entry_flow/block2': [2, 28, 28, 2],
'xception/entry_flow/block4': [2, 14, 14, 4],
'xception/middle_flow/block1': [2, 14, 14, 4],
'xception/exit_flow/block1': [2, 7, 7, 8],
'xception/exit_flow/block2': [2, 7, 7, 16]}
'xception/entry_flow/conv1_1': [2, 32, 32, 32],
'xception/entry_flow/block1': [2, 16, 16, 1],
'xception/entry_flow/block2': [2, 8, 8, 2],
'xception/entry_flow/block4': [2, 4, 4, 4],
'xception/middle_flow/block1': [2, 4, 4, 4],
'xception/exit_flow/block1': [2, 2, 2, 8],
'xception/exit_flow/block2': [2, 2, 2, 16]}
for endpoint, shape in six.iteritems(endpoint_to_shape):
self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
def testFullyConvolutionalEndpointShapes(self):
global_pool = False
num_classes = 10
inputs = create_test_input(2, 321, 321, 3)
num_classes = 3
inputs = create_test_input(2, 65, 65, 3)
with slim.arg_scope(xception.xception_arg_scope()):
_, end_points = self._xception_small(
inputs,
......@@ -324,21 +327,21 @@ class XceptionNetworkTest(tf.test.TestCase):
global_pool=global_pool,
scope='xception')
endpoint_to_shape = {
'xception/entry_flow/conv1_1': [2, 161, 161, 32],
'xception/entry_flow/block1': [2, 81, 81, 1],
'xception/entry_flow/block2': [2, 41, 41, 2],
'xception/entry_flow/block4': [2, 21, 21, 4],
'xception/middle_flow/block1': [2, 21, 21, 4],
'xception/exit_flow/block1': [2, 11, 11, 8],
'xception/exit_flow/block2': [2, 11, 11, 16]}
'xception/entry_flow/conv1_1': [2, 33, 33, 32],
'xception/entry_flow/block1': [2, 17, 17, 1],
'xception/entry_flow/block2': [2, 9, 9, 2],
'xception/entry_flow/block4': [2, 5, 5, 4],
'xception/middle_flow/block1': [2, 5, 5, 4],
'xception/exit_flow/block1': [2, 3, 3, 8],
'xception/exit_flow/block2': [2, 3, 3, 16]}
for endpoint, shape in six.iteritems(endpoint_to_shape):
self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
def testAtrousFullyConvolutionalEndpointShapes(self):
global_pool = False
num_classes = 10
num_classes = 3
output_stride = 8
inputs = create_test_input(2, 321, 321, 3)
inputs = create_test_input(2, 65, 65, 3)
with slim.arg_scope(xception.xception_arg_scope()):
_, end_points = self._xception_small(
inputs,
......@@ -347,12 +350,12 @@ class XceptionNetworkTest(tf.test.TestCase):
output_stride=output_stride,
scope='xception')
endpoint_to_shape = {
'xception/entry_flow/block1': [2, 81, 81, 1],
'xception/entry_flow/block2': [2, 41, 41, 2],
'xception/entry_flow/block4': [2, 41, 41, 4],
'xception/middle_flow/block1': [2, 41, 41, 4],
'xception/exit_flow/block1': [2, 41, 41, 8],
'xception/exit_flow/block2': [2, 41, 41, 16]}
'xception/entry_flow/block1': [2, 17, 17, 1],
'xception/entry_flow/block2': [2, 9, 9, 2],
'xception/entry_flow/block4': [2, 9, 9, 4],
'xception/middle_flow/block1': [2, 9, 9, 4],
'xception/exit_flow/block1': [2, 9, 9, 8],
'xception/exit_flow/block2': [2, 9, 9, 16]}
for endpoint, shape in six.iteritems(endpoint_to_shape):
self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
......@@ -460,15 +463,15 @@ class XceptionNetworkTest(tf.test.TestCase):
inputs,
num_classes=10,
reuse=True)
self.assertItemsEqual(end_points0.keys(), end_points1.keys())
self.assertItemsEqual(list(end_points0.keys()), list(end_points1.keys()))
def testUseBoundedAcitvation(self):
global_pool = False
num_classes = 10
output_stride = 8
num_classes = 3
output_stride = 16
for use_bounded_activation in (True, False):
tf.reset_default_graph()
inputs = create_test_input(2, 321, 321, 3)
inputs = create_test_input(2, 65, 65, 3)
with slim.arg_scope(xception.xception_arg_scope(
use_bounded_activation=use_bounded_activation)):
_, _ = self._xception_small(
......
# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -52,11 +53,12 @@ Alan L. Yuille (* equal contribution)
(https://arxiv.org/abs/1412.7062)
"""
import tensorflow as tf
from tensorflow.contrib import slim as contrib_slim
from deeplab.core import dense_prediction_cell
from deeplab.core import feature_extractor
from deeplab.core import utils
slim = tf.contrib.slim
slim = contrib_slim
LOGITS_SCOPE_NAME = 'logits'
MERGED_LOGITS_SCOPE = 'merged_logits'
......@@ -66,6 +68,8 @@ CONCAT_PROJECTION_SCOPE = 'concat_projection'
DECODER_SCOPE = 'decoder'
META_ARCHITECTURE_SCOPE = 'meta_architecture'
PROB_SUFFIX = '_prob'
_resize_bilinear = utils.resize_bilinear
scale_dimension = utils.scale_dimension
split_separable_conv2d = utils.split_separable_conv2d
......@@ -158,6 +162,7 @@ def predict_labels_multi_scale(images,
# Compute average prediction across different scales and flipped images.
predictions = tf.reduce_mean(tf.concat(predictions, 4), axis=4)
outputs_to_predictions[output] = tf.argmax(predictions, 3)
predictions[output + PROB_SUFFIX] = tf.nn.softmax(predictions)
return outputs_to_predictions
......@@ -195,6 +200,7 @@ def predict_labels(images, model_options, image_pyramid=None):
tf.shape(images)[1:3],
scales_to_logits[MERGED_LOGITS_SCOPE].dtype)
predictions[output] = tf.argmax(logits, 3)
predictions[output + PROB_SUFFIX] = tf.nn.softmax(logits)
else:
argmax_results = tf.argmax(logits, 3)
argmax_results = tf.image.resize_nearest_neighbor(
......@@ -203,7 +209,11 @@ def predict_labels(images, model_options, image_pyramid=None):
align_corners=True,
name='resize_prediction')
predictions[output] = tf.squeeze(argmax_results, 3)
predictions[output + PROB_SUFFIX] = tf.image.resize_bilinear(
tf.nn.softmax(logits),
tf.shape(images)[1:3],
align_corners=True,
name='resize_prob')
return predictions
......@@ -389,8 +399,7 @@ def extract_features(images,
is_training=is_training,
preprocessed_images_dtype=model_options.preprocessed_images_dtype,
fine_tune_batch_norm=fine_tune_batch_norm,
nas_stem_output_num_conv_filters=(
model_options.nas_stem_output_num_conv_filters),
nas_architecture_options=model_options.nas_architecture_options,
nas_training_hyper_parameters=nas_training_hyper_parameters,
use_bounded_activation=model_options.use_bounded_activation)
......@@ -419,26 +428,26 @@ def extract_features(images,
# could express the ASPP module as one particular dense prediction
# cell architecture. We do not do so but leave the following codes
# for backward compatibility.
batch_norm_params = {
'is_training': is_training and fine_tune_batch_norm,
'decay': 0.9997,
'epsilon': 1e-5,
'scale': True,
}
batch_norm_params = utils.get_batch_norm_params(
decay=0.9997,
epsilon=1e-5,
scale=True,
is_training=(is_training and fine_tune_batch_norm),
sync_batch_norm_method=model_options.sync_batch_norm_method)
batch_norm = utils.get_batch_norm_fn(
model_options.sync_batch_norm_method)
activation_fn = (
tf.nn.relu6 if model_options.use_bounded_activation else tf.nn.relu)
with slim.arg_scope(
[slim.conv2d, slim.separable_conv2d],
weights_regularizer=slim.l2_regularizer(weight_decay),
activation_fn=activation_fn,
normalizer_fn=slim.batch_norm,
normalizer_fn=batch_norm,
padding='SAME',
stride=1,
reuse=reuse):
with slim.arg_scope([slim.batch_norm], **batch_norm_params):
depth = 256
with slim.arg_scope([batch_norm], **batch_norm_params):
depth = model_options.aspp_convs_filters
branch_logits = []
if model_options.add_image_level_feature:
......@@ -470,8 +479,18 @@ def extract_features(images,
features, axis=[1, 2], keepdims=True)
resize_height = pool_height
resize_width = pool_width
image_feature_activation_fn = tf.nn.relu
image_feature_normalizer_fn = batch_norm
if model_options.aspp_with_squeeze_and_excitation:
image_feature_activation_fn = tf.nn.sigmoid
if model_options.image_se_uses_qsigmoid:
image_feature_activation_fn = utils.q_sigmoid
image_feature_normalizer_fn = None
image_feature = slim.conv2d(
image_feature, depth, 1, scope=IMAGE_POOLING_SCOPE)
image_feature, depth, 1,
activation_fn=image_feature_activation_fn,
normalizer_fn=image_feature_normalizer_fn,
scope=IMAGE_POOLING_SCOPE)
image_feature = _resize_bilinear(
image_feature,
[resize_height, resize_width],
......@@ -482,7 +501,8 @@ def extract_features(images,
if isinstance(resize_width, tf.Tensor):
resize_width = None
image_feature.set_shape([None, resize_height, resize_width, depth])
branch_logits.append(image_feature)
if not model_options.aspp_with_squeeze_and_excitation:
branch_logits.append(image_feature)
# Employ a 1x1 convolution.
branch_logits.append(slim.conv2d(features, depth, 1,
......@@ -506,13 +526,17 @@ def extract_features(images,
# Merge branch logits.
concat_logits = tf.concat(branch_logits, 3)
concat_logits = slim.conv2d(
concat_logits, depth, 1, scope=CONCAT_PROJECTION_SCOPE)
concat_logits = slim.dropout(
concat_logits,
keep_prob=0.9,
is_training=is_training,
scope=CONCAT_PROJECTION_SCOPE + '_dropout')
if model_options.aspp_with_concat_projection:
concat_logits = slim.conv2d(
concat_logits, depth, 1, scope=CONCAT_PROJECTION_SCOPE)
concat_logits = slim.dropout(
concat_logits,
keep_prob=0.9,
is_training=is_training,
scope=CONCAT_PROJECTION_SCOPE + '_dropout')
if (model_options.add_image_level_feature and
model_options.aspp_with_squeeze_and_excitation):
concat_logits *= image_feature
return concat_logits, end_points
......@@ -552,13 +576,19 @@ def _get_logits(images,
fine_tune_batch_norm=fine_tune_batch_norm,
nas_training_hyper_parameters=nas_training_hyper_parameters)
if model_options.decoder_output_stride is not None:
if model_options.decoder_output_stride:
crop_size = model_options.crop_size
if crop_size is None:
crop_size = [tf.shape(images)[1], tf.shape(images)[2]]
features = refine_by_decoder(
features,
end_points,
crop_size=model_options.crop_size,
crop_size=crop_size,
decoder_output_stride=model_options.decoder_output_stride,
decoder_use_separable_conv=model_options.decoder_use_separable_conv,
decoder_use_sum_merge=model_options.decoder_use_sum_merge,
decoder_filters=model_options.decoder_filters,
decoder_output_is_logits=model_options.decoder_output_is_logits,
model_variant=model_options.model_variant,
weight_decay=weight_decay,
reuse=reuse,
......@@ -568,15 +598,19 @@ def _get_logits(images,
outputs_to_logits = {}
for output in sorted(model_options.outputs_to_num_classes):
outputs_to_logits[output] = get_branch_logits(
features,
model_options.outputs_to_num_classes[output],
model_options.atrous_rates,
aspp_with_batch_norm=model_options.aspp_with_batch_norm,
kernel_size=model_options.logits_kernel_size,
weight_decay=weight_decay,
reuse=reuse,
scope_suffix=output)
if model_options.decoder_output_is_logits:
outputs_to_logits[output] = tf.identity(features,
name=output)
else:
outputs_to_logits[output] = get_branch_logits(
features,
model_options.outputs_to_num_classes[output],
model_options.atrous_rates,
aspp_with_batch_norm=model_options.aspp_with_batch_norm,
kernel_size=model_options.logits_kernel_size,
weight_decay=weight_decay,
reuse=reuse,
scope_suffix=output)
return outputs_to_logits
......@@ -586,12 +620,16 @@ def refine_by_decoder(features,
crop_size=None,
decoder_output_stride=None,
decoder_use_separable_conv=False,
decoder_use_sum_merge=False,
decoder_filters=256,
decoder_output_is_logits=False,
model_variant=None,
weight_decay=0.0001,
reuse=None,
is_training=False,
fine_tune_batch_norm=False,
use_bounded_activation=False):
use_bounded_activation=False,
sync_batch_norm_method='None'):
"""Adds the decoder to obtain sharper segmentation results.
Args:
......@@ -604,6 +642,9 @@ def refine_by_decoder(features,
decoder_output_stride: A list of integers specifying the output stride of
low-level features used in the decoder module.
decoder_use_separable_conv: Employ separable convolution for decoder or not.
decoder_use_sum_merge: Boolean, decoder uses simple sum merge or not.
decoder_filters: Integer, decoder filter size.
decoder_output_is_logits: Boolean, using decoder output as logits or not.
model_variant: Model variant for feature extraction.
weight_decay: The weight decay for model variables.
reuse: Reuse the model variables or not.
......@@ -611,6 +652,9 @@ def refine_by_decoder(features,
fine_tune_batch_norm: Fine-tune the batch norm parameters or not.
use_bounded_activation: Whether or not to use bounded activations. Bounded
activations better lend themselves to quantized inference.
sync_batch_norm_method: String, method used to sync batch norm. Currently
only support `None` (no sync batch norm) and `tpu` (use tpu code to
sync batch norm).
Returns:
Decoder output with size [batch, decoder_height, decoder_width,
......@@ -621,22 +665,40 @@ def refine_by_decoder(features,
"""
if crop_size is None:
raise ValueError('crop_size must be provided when using decoder.')
batch_norm_params = {
'is_training': is_training and fine_tune_batch_norm,
'decay': 0.9997,
'epsilon': 1e-5,
'scale': True,
}
batch_norm_params = utils.get_batch_norm_params(
decay=0.9997,
epsilon=1e-5,
scale=True,
is_training=(is_training and fine_tune_batch_norm),
sync_batch_norm_method=sync_batch_norm_method)
batch_norm = utils.get_batch_norm_fn(sync_batch_norm_method)
decoder_depth = decoder_filters
projected_filters = 48
if decoder_use_sum_merge:
# When using sum merge, the projected filters must be equal to decoder
# filters.
projected_filters = decoder_filters
if decoder_output_is_logits:
# Overwrite the setting when decoder output is logits.
activation_fn = None
normalizer_fn = None
conv2d_kernel = 1
# Use original conv instead of separable conv.
decoder_use_separable_conv = False
else:
# Default setting when decoder output is not logits.
activation_fn = tf.nn.relu6 if use_bounded_activation else tf.nn.relu
normalizer_fn = batch_norm
conv2d_kernel = 3
with slim.arg_scope(
[slim.conv2d, slim.separable_conv2d],
weights_regularizer=slim.l2_regularizer(weight_decay),
activation_fn=tf.nn.relu6 if use_bounded_activation else tf.nn.relu,
normalizer_fn=slim.batch_norm,
activation_fn=activation_fn,
normalizer_fn=normalizer_fn,
padding='SAME',
stride=1,
reuse=reuse):
with slim.arg_scope([slim.batch_norm], **batch_norm_params):
with slim.arg_scope([batch_norm], **batch_norm_params):
with tf.variable_scope(DECODER_SCOPE, DECODER_SCOPE, [features]):
decoder_features = features
decoder_stage = 0
......@@ -652,7 +714,9 @@ def refine_by_decoder(features,
for i, name in enumerate(feature_list):
decoder_features_list = [decoder_features]
# MobileNet and NAS variants use different naming convention.
if 'mobilenet' in model_variant or model_variant.startswith('nas'):
if ('mobilenet' in model_variant or
model_variant.startswith('mnas') or
model_variant.startswith('nas')):
feature_name = name
else:
feature_name = '{}/{}'.format(
......@@ -660,7 +724,7 @@ def refine_by_decoder(features,
decoder_features_list.append(
slim.conv2d(
end_points[feature_name],
48,
projected_filters,
1,
scope='feature_projection' + str(i) + scope_suffix))
# Determine the output size.
......@@ -675,33 +739,115 @@ def refine_by_decoder(features,
w = (None if isinstance(decoder_width, tf.Tensor)
else decoder_width)
decoder_features_list[j].set_shape([None, h, w, None])
decoder_depth = 256
if decoder_use_separable_conv:
decoder_features = split_separable_conv2d(
tf.concat(decoder_features_list, 3),
filters=decoder_depth,
rate=1,
weight_decay=weight_decay,
scope='decoder_conv0' + scope_suffix)
decoder_features = split_separable_conv2d(
decoder_features,
filters=decoder_depth,
rate=1,
if decoder_use_sum_merge:
decoder_features = _decoder_with_sum_merge(
decoder_features_list,
decoder_depth,
conv2d_kernel=conv2d_kernel,
decoder_use_separable_conv=decoder_use_separable_conv,
weight_decay=weight_decay,
scope='decoder_conv1' + scope_suffix)
scope_suffix=scope_suffix)
else:
num_convs = 2
decoder_features = slim.repeat(
tf.concat(decoder_features_list, 3),
num_convs,
slim.conv2d,
if not decoder_use_separable_conv:
scope_suffix = str(i) + scope_suffix
decoder_features = _decoder_with_concat_merge(
decoder_features_list,
decoder_depth,
3,
scope='decoder_conv' + str(i) + scope_suffix)
decoder_use_separable_conv=decoder_use_separable_conv,
weight_decay=weight_decay,
scope_suffix=scope_suffix)
decoder_stage += 1
return decoder_features
def _decoder_with_sum_merge(decoder_features_list,
                            decoder_depth,
                            conv2d_kernel=3,
                            decoder_use_separable_conv=True,
                            weight_decay=0.0001,
                            scope_suffix=''):
  """Merges exactly two decoder feature maps by element-wise sum.

  The first feature map is passed through a single convolution (separable
  or plain, depending on decoder_use_separable_conv) and the second is
  added to the result unchanged.

  Args:
    decoder_features_list: A list holding exactly two decoder features.
    decoder_depth: Integer, number of filters used in the convolution.
    conv2d_kernel: Integer, kernel size of the plain convolution (only used
      when decoder_use_separable_conv is False).
    decoder_use_separable_conv: Boolean, whether to use a separable conv.
    weight_decay: Weight decay for the model variables.
    scope_suffix: String appended to the variable scope names.

  Returns:
    The sum-merged decoder features.

  Raises:
    RuntimeError: If decoder_features_list does not have length 2.
  """
  if len(decoder_features_list) != 2:
    raise RuntimeError('Expect decoder_features has length 2.')
  to_convolve, to_add = decoder_features_list
  # Sum merge only applies a single convolution before the addition.
  if decoder_use_separable_conv:
    convolved = split_separable_conv2d(
        to_convolve,
        filters=decoder_depth,
        rate=1,
        weight_decay=weight_decay,
        scope='decoder_split_sep_conv0' + scope_suffix)
  else:
    convolved = slim.conv2d(
        to_convolve,
        decoder_depth,
        conv2d_kernel,
        scope='decoder_conv0' + scope_suffix)
  return convolved + to_add
def _decoder_with_concat_merge(decoder_features_list,
                               decoder_depth,
                               decoder_use_separable_conv=True,
                               weight_decay=0.0001,
                               scope_suffix=''):
  """Merges decoder features by concatenation followed by two convolutions.

  This is the decoder module proposed in the DeepLabv3+ paper: the inputs
  are concatenated along the channel axis and smoothed by two convolutions
  (separable or plain, depending on decoder_use_separable_conv).

  Args:
    decoder_features_list: A list of decoder features.
    decoder_depth: Integer, number of filters used in the convolutions.
    decoder_use_separable_conv: Boolean, whether to use separable convs.
    weight_decay: Weight decay for the model variables.
    scope_suffix: String appended to the variable scope names.

  Returns:
    The concat-merged decoder features.
  """
  concatenated = tf.concat(decoder_features_list, 3)
  if decoder_use_separable_conv:
    smoothed = split_separable_conv2d(
        concatenated,
        filters=decoder_depth,
        rate=1,
        weight_decay=weight_decay,
        scope='decoder_conv0' + scope_suffix)
    return split_separable_conv2d(
        smoothed,
        filters=decoder_depth,
        rate=1,
        weight_decay=weight_decay,
        scope='decoder_conv1' + scope_suffix)
  # Plain convolutions: apply the same 3x3 conv twice via slim.repeat.
  return slim.repeat(
      concatenated,
      2,
      slim.conv2d,
      decoder_depth,
      3,
      scope='decoder_conv' + scope_suffix)
def get_branch_logits(features,
num_classes,
atrous_rates=None,
......
# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -17,14 +18,20 @@
See model.py for more details and usage.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
import tensorflow as tf
from tensorflow.python.ops import math_ops
from tensorflow.contrib import quantize as contrib_quantize
from tensorflow.contrib import tfprof as contrib_tfprof
from deeplab import common
from deeplab import model
from deeplab.datasets import data_generator
from deeplab.utils import train_utils
from deployment import model_deploy
slim = tf.contrib.slim
flags = tf.app.flags
FLAGS = flags.FLAGS
......@@ -74,6 +81,12 @@ flags.DEFINE_string('profile_logdir', None,
# Settings for training strategy.
flags.DEFINE_enum('optimizer', 'momentum', ['momentum', 'adam'],
'Which optimizer to use.')
# Momentum optimizer flags
flags.DEFINE_enum('learning_policy', 'poly', ['poly', 'step'],
'Learning rate policy for training.')
......@@ -82,6 +95,12 @@ flags.DEFINE_enum('learning_policy', 'poly', ['poly', 'step'],
flags.DEFINE_float('base_learning_rate', .0001,
'The base learning rate for model training.')
flags.DEFINE_float('decay_steps', 0.0,
'Decay steps for polynomial learning rate schedule.')
flags.DEFINE_float('end_learning_rate', 0.0,
'End learning rate for polynomial learning rate schedule.')
flags.DEFINE_float('learning_rate_decay_factor', 0.1,
'The rate to decay the base learning rate.')
......@@ -96,6 +115,11 @@ flags.DEFINE_integer('training_number_of_steps', 30000,
flags.DEFINE_float('momentum', 0.9, 'The momentum value to use')
# Adam optimizer flags
flags.DEFINE_float('adam_learning_rate', 0.001,
'Learning rate for the adam optimizer.')
flags.DEFINE_float('adam_epsilon', 1e-08, 'Adam optimizer epsilon.')
# When fine_tune_batch_norm=True, use at least batch size larger than 12
# (batch size more than 16 is better). Otherwise, one could use smaller batch
# size and set fine_tune_batch_norm=False.
......@@ -174,7 +198,6 @@ flags.DEFINE_integer(
'top_k_percent_pixels=0.25, then mining percent will gradually reduce from '
'100% to 25% until 100K steps after which we only mine top 25% pixels.')
flags.DEFINE_float(
'top_k_percent_pixels', 1.0,
'The top k percent pixels (in terms of the loss values) used to compute '
......@@ -240,207 +263,34 @@ def _build_deeplab(iterator, outputs_to_num_classes, ignore_label):
samples[common.LABEL],
num_classes,
ignore_label,
loss_weight=1.0,
loss_weight=model_options.label_weights,
upsample_logits=FLAGS.upsample_logits,
hard_example_mining_step=FLAGS.hard_example_mining_step,
top_k_percent_pixels=FLAGS.top_k_percent_pixels,
scope=output)
# Log the summary
_log_summaries(samples[common.IMAGE], samples[common.LABEL], num_classes,
output_type_dict[model.MERGED_LOGITS_SCOPE])
def _tower_loss(iterator, num_of_classes, ignore_label, scope, reuse_variable):
  """Computes the total loss of the deeplab model on a single tower.

  Builds the deeplab graph for this tower (reusing variables for every
  tower after the first), then sums the tower's task losses with its
  regularization loss, emitting a scalar summary for each component.

  Args:
    iterator: An iterator of type tf.data.Iterator for images and labels.
    num_of_classes: Number of classes for the dataset.
    ignore_label: Ignore label for the dataset.
    scope: Unique prefix string identifying the deeplab tower.
    reuse_variable: If the variables should be reused.

  Returns:
    A scalar tensor: the total loss for a batch of data on this tower.
  """
  with tf.variable_scope(
      tf.get_variable_scope(), reuse=True if reuse_variable else None):
    _build_deeplab(iterator, {common.OUTPUT_TYPE: num_of_classes}, ignore_label)

  task_losses = tf.losses.get_losses(scope=scope)
  for task_loss in task_losses:
    tf.summary.scalar('Losses/%s' % task_loss.op.name, task_loss)

  reg_loss = tf.losses.get_regularization_loss(scope=scope)
  tf.summary.scalar('Losses/%s' % reg_loss.op.name, reg_loss)

  return tf.add_n([tf.add_n(task_losses), reg_loss])
def _average_gradients(tower_grads):
  """Averages the gradient of each shared variable across all towers.

  Note that this function provides a synchronization point across all
  towers, since every averaged gradient depends on every tower's copy.

  Args:
    tower_grads: List of lists of (gradient, variable) tuples. The outer
      list is over individual gradients. The inner list is over the
      gradient calculation for each tower.

  Returns:
    List of (gradient, variable) pairs where the gradient has been averaged
    across all towers.
  """
  def _mean_grad(per_variable_grads):
    # per_variable_grads looks like:
    # ((grad0_gpu0, var0_gpu0), ..., (grad0_gpuN, var0_gpuN)).
    grads, variables = zip(*per_variable_grads)
    averaged = tf.reduce_mean(tf.stack(grads, axis=0), axis=0)
    # All towers share the variable, so the first tower's copy suffices.
    return averaged, variables[0]

  return [_mean_grad(grad_and_vars) for grad_and_vars in zip(*tower_grads)]
def _log_summaries(input_image, label, num_of_classes, output):
  """Registers TensorBoard summaries for the model.

  Always histograms every model variable; image summaries of the input,
  ground-truth label and prediction are added only when
  FLAGS.save_summaries_images is set.

  Args:
    input_image: Input image of the model. Its shape is [batch_size,
      height, width, channel].
    label: Label of the image. Its shape is [batch_size, height, width].
    num_of_classes: The number of classes of the dataset.
    output: Output of the model. Its shape is [batch_size, height, width].
  """
  # Add summaries for model variables.
  for model_var in tf.model_variables():
    tf.summary.histogram(model_var.op.name, model_var)

  if not FLAGS.save_summaries_images:
    return

  # Add summaries for images, labels, semantic predictions.
  tf.summary.image('samples/%s' % common.IMAGE, input_image)

  # Scale up summary image pixel values for better visualization.
  pixel_scaling = max(1, 255 // num_of_classes)
  scaled_label = tf.cast(label * pixel_scaling, tf.uint8)
  tf.summary.image('samples/%s' % common.LABEL, scaled_label)

  predictions = tf.expand_dims(tf.argmax(output, 3), -1)
  scaled_predictions = tf.cast(predictions * pixel_scaling, tf.uint8)
  tf.summary.image('samples/%s' % common.OUTPUT_TYPE, scaled_predictions)
def _train_deeplab_model(iterator, num_of_classes, ignore_label):
  """Trains the deeplab model.

  Builds a multi-GPU (in-graph replicated) training graph: one loss tower
  per clone, gradients averaged on the CPU, then a single synchronous
  update op. Statement order matters here — losses must be built before
  the (optional) quantization rewrite, which must precede gradient
  computation.

  Args:
    iterator: An iterator of type tf.data.Iterator for images and labels.
    num_of_classes: Number of classes for the dataset.
    ignore_label: Ignore label for the dataset.

  Returns:
    train_tensor: A tensor to update the model variables.
    summary_op: An operation to log the summaries.
  """
  global_step = tf.train.get_or_create_global_step()

  learning_rate = train_utils.get_model_learning_rate(
      FLAGS.learning_policy, FLAGS.base_learning_rate,
      FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,
      FLAGS.training_number_of_steps, FLAGS.learning_power,
      FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
  tf.summary.scalar('learning_rate', learning_rate)

  optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)

  # Build one loss tower per clone; variables are created by the first
  # tower and reused (reuse_variable=True) by all subsequent ones.
  tower_losses = []
  tower_grads = []
  for i in range(FLAGS.num_clones):
    with tf.device('/gpu:%d' % i):
      # First tower has default name scope.
      name_scope = ('clone_%d' % i) if i else ''
      with tf.name_scope(name_scope) as scope:
        loss = _tower_loss(
            iterator=iterator,
            num_of_classes=num_of_classes,
            ignore_label=ignore_label,
            scope=scope,
            reuse_variable=(i != 0))
        tower_losses.append(loss)

  # The quantization rewrite must run after the forward graph exists but
  # before gradients are computed, so it sits between the two clone loops.
  if FLAGS.quantize_delay_step >= 0:
    if FLAGS.num_clones > 1:
      raise ValueError('Quantization doesn\'t support multi-clone yet.')
    tf.contrib.quantize.create_training_graph(
        quant_delay=FLAGS.quantize_delay_step)

  for i in range(FLAGS.num_clones):
    with tf.device('/gpu:%d' % i):
      name_scope = ('clone_%d' % i) if i else ''
      with tf.name_scope(name_scope) as scope:
        grads = optimizer.compute_gradients(tower_losses[i])
        tower_grads.append(grads)

  # Gradient averaging and the parameter update happen once, on the CPU.
  with tf.device('/cpu:0'):
    grads_and_vars = _average_gradients(tower_grads)

    # Modify the gradients for biases and last layer variables.
    last_layers = model.get_extra_layer_scopes(
        FLAGS.last_layers_contain_logits_only)
    grad_mult = train_utils.get_model_gradient_multipliers(
        last_layers, FLAGS.last_layer_gradient_multiplier)
    if grad_mult:
      grads_and_vars = tf.contrib.training.multiply_gradients(
          grads_and_vars, grad_mult)

    # Create gradient update op.
    grad_updates = optimizer.apply_gradients(
        grads_and_vars, global_step=global_step)

    # Gather update_ops. These contain, for example,
    # the updates for the batch_norm variables created by model_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    update_ops.append(grad_updates)
    update_op = tf.group(*update_ops)

    total_loss = tf.losses.get_total_loss(add_regularization_losses=True)

    # Print total loss to the terminal.
    # This implementation is mirrored from tf.slim.summaries.
    should_log = math_ops.equal(math_ops.mod(global_step, FLAGS.log_steps), 0)
    total_loss = tf.cond(
        should_log,
        lambda: tf.Print(total_loss, [total_loss], 'Total loss is :'),
        lambda: total_loss)

    tf.summary.scalar('total_loss', total_loss)
    # Running train_tensor forces the whole update (gradients + batch-norm
    # moving averages) via the control dependency.
    with tf.control_dependencies([update_op]):
      train_tensor = tf.identity(total_loss, name='train_op')

    # Excludes summaries from towers other than the first one.
    summary_op = tf.summary.merge_all(scope='(?!clone_)')

  return train_tensor, summary_op
def main(unused_argv):
tf.logging.set_verbosity(tf.logging.INFO)
# Set up deployment (i.e., multi-GPUs and/or multi-replicas).
config = model_deploy.DeploymentConfig(
num_clones=FLAGS.num_clones,
clone_on_cpu=FLAGS.clone_on_cpu,
replica_id=FLAGS.task,
num_replicas=FLAGS.num_replicas,
num_ps_tasks=FLAGS.num_ps_tasks)
# Split the batch across GPUs.
assert FLAGS.train_batch_size % config.num_clones == 0, (
'Training batch size not divisble by number of clones (GPUs).')
clone_batch_size = FLAGS.train_batch_size // config.num_clones
tf.gfile.MakeDirs(FLAGS.train_logdir)
tf.logging.info('Training on %s set', FLAGS.train_split)
graph = tf.Graph()
with graph.as_default():
with tf.device(tf.train.replica_device_setter(ps_tasks=FLAGS.num_ps_tasks)):
assert FLAGS.train_batch_size % FLAGS.num_clones == 0, (
'Training batch size not divisble by number of clones (GPUs).')
clone_batch_size = FLAGS.train_batch_size // FLAGS.num_clones
with tf.Graph().as_default() as graph:
with tf.device(config.inputs_device()):
dataset = data_generator.Dataset(
dataset_name=FLAGS.dataset,
split_name=FLAGS.train_split,
......@@ -454,21 +304,136 @@ def main(unused_argv):
max_scale_factor=FLAGS.max_scale_factor,
scale_factor_step_size=FLAGS.scale_factor_step_size,
model_variant=FLAGS.model_variant,
num_readers=2,
num_readers=4,
is_training=True,
should_shuffle=True,
should_repeat=True)
train_tensor, summary_op = _train_deeplab_model(
dataset.get_one_shot_iterator(), dataset.num_of_classes,
dataset.ignore_label)
# Soft placement allows placing on CPU ops without GPU implementation.
session_config = tf.ConfigProto(
allow_soft_placement=True, log_device_placement=False)
# Create the global step on the device storing the variables.
with tf.device(config.variables_device()):
global_step = tf.train.get_or_create_global_step()
# Define the model and create clones.
model_fn = _build_deeplab
model_args = (dataset.get_one_shot_iterator(), {
common.OUTPUT_TYPE: dataset.num_of_classes
}, dataset.ignore_label)
clones = model_deploy.create_clones(config, model_fn, args=model_args)
# Gather update_ops from the first clone. These contain, for example,
# the updates for the batch_norm variables created by model_fn.
first_clone_scope = config.clone_scope(0)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)
# Gather initial summaries.
summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
# Add summaries for model variables.
for model_var in tf.model_variables():
summaries.add(tf.summary.histogram(model_var.op.name, model_var))
# Add summaries for images, labels, semantic predictions
if FLAGS.save_summaries_images:
summary_image = graph.get_tensor_by_name(
('%s/%s:0' % (first_clone_scope, common.IMAGE)).strip('/'))
summaries.add(
tf.summary.image('samples/%s' % common.IMAGE, summary_image))
first_clone_label = graph.get_tensor_by_name(
('%s/%s:0' % (first_clone_scope, common.LABEL)).strip('/'))
# Scale up summary image pixel values for better visualization.
pixel_scaling = max(1, 255 // dataset.num_classes)
summary_label = tf.cast(first_clone_label * pixel_scaling, tf.uint8)
summaries.add(
tf.summary.image('samples/%s' % common.LABEL, summary_label))
first_clone_output = graph.get_tensor_by_name(
('%s/%s:0' % (first_clone_scope, common.OUTPUT_TYPE)).strip('/'))
predictions = tf.expand_dims(tf.argmax(first_clone_output, 3), -1)
summary_predictions = tf.cast(predictions * pixel_scaling, tf.uint8)
summaries.add(
tf.summary.image(
'samples/%s' % common.OUTPUT_TYPE, summary_predictions))
# Add summaries for losses.
for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))
# Build the optimizer based on the device specification.
with tf.device(config.optimizer_device()):
learning_rate = train_utils.get_model_learning_rate(
FLAGS.learning_policy,
FLAGS.base_learning_rate,
FLAGS.learning_rate_decay_step,
FLAGS.learning_rate_decay_factor,
FLAGS.training_number_of_steps,
FLAGS.learning_power,
FLAGS.slow_start_step,
FLAGS.slow_start_learning_rate,
decay_steps=FLAGS.decay_steps,
end_learning_rate=FLAGS.end_learning_rate)
summaries.add(tf.summary.scalar('learning_rate', learning_rate))
if FLAGS.optimizer == 'momentum':
optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
elif FLAGS.optimizer == 'adam':
optimizer = tf.train.AdamOptimizer(
learning_rate=FLAGS.adam_learning_rate, epsilon=FLAGS.adam_epsilon)
else:
raise ValueError('Unknown optimizer')
if FLAGS.quantize_delay_step >= 0:
if FLAGS.num_clones > 1:
raise ValueError('Quantization doesn\'t support multi-clone yet.')
contrib_quantize.create_training_graph(
quant_delay=FLAGS.quantize_delay_step)
startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps
with tf.device(config.variables_device()):
total_loss, grads_and_vars = model_deploy.optimize_clones(
clones, optimizer)
total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
summaries.add(tf.summary.scalar('total_loss', total_loss))
# Modify the gradients for biases and last layer variables.
last_layers = model.get_extra_layer_scopes(
FLAGS.last_layers_contain_logits_only)
grad_mult = train_utils.get_model_gradient_multipliers(
last_layers, FLAGS.last_layer_gradient_multiplier)
if grad_mult:
grads_and_vars = slim.learning.multiply_gradients(
grads_and_vars, grad_mult)
# Create gradient update op.
grad_updates = optimizer.apply_gradients(
grads_and_vars, global_step=global_step)
update_ops.append(grad_updates)
update_op = tf.group(*update_ops)
with tf.control_dependencies([update_op]):
train_tensor = tf.identity(total_loss, name='train_op')
# Add the summaries from the first clone. These contain the summaries
# created by model_fn and either optimize_clones() or _gather_clone_loss().
summaries |= set(
tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
# Merge all summaries together.
summary_op = tf.summary.merge(list(summaries))
# Soft placement allows placing on CPU ops without GPU implementation.
session_config = tf.ConfigProto(
allow_soft_placement=True, log_device_placement=False)
# Start the training.
profile_dir = FLAGS.profile_logdir
if profile_dir is not None:
tf.gfile.MakeDirs(profile_dir)
with contrib_tfprof.ProfileContext(
enabled=profile_dir is not None, profile_dir=profile_dir):
init_fn = None
if FLAGS.tf_initial_checkpoint:
init_fn = train_utils.get_model_init_fn(
......@@ -478,33 +443,19 @@ def main(unused_argv):
last_layers,
ignore_missing_vars=True)
scaffold = tf.train.Scaffold(
slim.learning.train(
train_tensor,
logdir=FLAGS.train_logdir,
log_every_n_steps=FLAGS.log_steps,
master=FLAGS.master,
number_of_steps=FLAGS.training_number_of_steps,
is_chief=(FLAGS.task == 0),
session_config=session_config,
startup_delay_steps=startup_delay_steps,
init_fn=init_fn,
summary_op=summary_op,
)
stop_hook = tf.train.StopAtStepHook(
last_step=FLAGS.training_number_of_steps)
profile_dir = FLAGS.profile_logdir
if profile_dir is not None:
tf.gfile.MakeDirs(profile_dir)
with tf.contrib.tfprof.ProfileContext(
enabled=profile_dir is not None, profile_dir=profile_dir):
with tf.train.MonitoredTrainingSession(
master=FLAGS.master,
is_chief=(FLAGS.task == 0),
config=session_config,
scaffold=scaffold,
checkpoint_dir=FLAGS.train_logdir,
summary_dir=FLAGS.train_logdir,
log_step_count_steps=FLAGS.log_steps,
save_summaries_steps=FLAGS.save_summaries_secs,
save_checkpoint_secs=FLAGS.save_interval_secs,
hooks=[stop_hook]) as sess:
while not sess.should_stop():
sess.run([train_tensor])
save_summaries_secs=FLAGS.save_summaries_secs,
save_interval_secs=FLAGS.save_interval_secs)
if __name__ == '__main__':
......
# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -15,15 +16,18 @@
"""Utility functions for training."""
import six
import tensorflow as tf
from tensorflow.contrib import framework as contrib_framework
from deeplab.core import preprocess_utils
from deeplab.core import utils
def _div_maybe_zero(total_loss, num_present):
  """Normalizes the total loss by the number of present (non-ignored) pixels.

  Args:
    total_loss: A scalar float tensor, the sum of per-pixel losses.
    num_present: A scalar float tensor, the number of pixels that contribute
      to the loss (i.e., pixels not equal to the ignore label).

  Returns:
    A scalar tensor equal to total_loss / max(num_present, 1e-5) when
    num_present > 0, and 0.0 otherwise.
  """
  # The tf.to_float(num_present > 0) factor forces the result to exactly 0
  # when no pixels are present, while the tf.maximum clamp keeps the divisor
  # away from zero so the division itself never produces inf/nan.
  # NOTE: the original block contained two consecutive return statements (a
  # diff artifact); only the tf.math.divide form is kept — tf.div is the
  # deprecated alias of the same op, so behavior is unchanged.
  return tf.to_float(num_present > 0) * tf.math.divide(
      total_loss,
      tf.maximum(1e-5, num_present))
def add_softmax_cross_entropy_loss_for_each_scale(scales_to_logits,
......@@ -34,6 +38,7 @@ def add_softmax_cross_entropy_loss_for_each_scale(scales_to_logits,
upsample_logits=True,
hard_example_mining_step=0,
top_k_percent_pixels=1.0,
gt_is_matting_map=False,
scope=None):
"""Adds softmax cross entropy loss for logits of each scale.
......@@ -43,7 +48,11 @@ def add_softmax_cross_entropy_loss_for_each_scale(scales_to_logits,
labels: Groundtruth labels with shape [batch, image_height, image_width, 1].
num_classes: Integer, number of target classes.
ignore_label: Integer, label to ignore.
loss_weight: Float, loss weight.
loss_weight: A float or a list of loss weights. If it is a float, it means
all the labels have the same weight. If it is a list of weights, then each
element in the list represents the weight for the label of its index, for
example, loss_weight = [0.1, 0.5] means the weight for label 0 is 0.1 and
the weight for label 1 is 0.5.
upsample_logits: Boolean, upsample logits or not.
hard_example_mining_step: An integer, the training step in which the hard
exampling mining kicks off. Note that we gradually reduce the mining
......@@ -54,14 +63,22 @@ def add_softmax_cross_entropy_loss_for_each_scale(scales_to_logits,
top_k_percent_pixels: A float, the value lies in [0.0, 1.0]. When its value
< 1.0, only compute the loss for the top k percent pixels (e.g., the top
20% pixels). This is useful for hard pixel mining.
gt_is_matting_map: If true, the groundtruth is a matting map of confidence
score. If false, the groundtruth is an integer valued class mask.
scope: String, the scope for the loss.
Raises:
ValueError: Label or logits is None.
ValueError: Label or logits is None, or groundtruth is matting map while
label is not floating value.
"""
if labels is None:
raise ValueError('No label for softmax cross entropy loss.')
# If input groundtruth is a matting map of confidence, check if the input
# labels are floating point values.
if gt_is_matting_map and not labels.dtype.is_floating:
raise ValueError('Labels must be floats if groundtruth is a matting map.')
for scale, logits in six.iteritems(scales_to_logits):
loss_scope = None
if scope:
......@@ -76,36 +93,70 @@ def add_softmax_cross_entropy_loss_for_each_scale(scales_to_logits,
scaled_labels = labels
else:
# Label is downsampled to the same size as logits.
# When gt_is_matting_map = true, label downsampling with nearest neighbor
# method may introduce artifacts. However, to avoid ignore_label from
# being interpolated with other labels, we still perform nearest neighbor
# interpolation.
# TODO(huizhongc): Change to bilinear interpolation by processing padded
# and non-padded label separately.
if gt_is_matting_map:
tf.logging.warning(
'Label downsampling with nearest neighbor may introduce artifacts.')
scaled_labels = tf.image.resize_nearest_neighbor(
labels,
preprocess_utils.resolve_shape(logits, 4)[1:3],
align_corners=True)
scaled_labels = tf.reshape(scaled_labels, shape=[-1])
not_ignore_mask = tf.to_float(tf.not_equal(scaled_labels,
ignore_label)) * loss_weight
one_hot_labels = tf.one_hot(
scaled_labels, num_classes, on_value=1.0, off_value=0.0)
if top_k_percent_pixels == 1.0:
# Compute the loss for all pixels.
tf.losses.softmax_cross_entropy(
one_hot_labels,
tf.reshape(logits, shape=[-1, num_classes]),
weights=not_ignore_mask,
scope=loss_scope)
weights = utils.get_label_weight_mask(
scaled_labels, ignore_label, num_classes, label_weights=loss_weight)
# Dimension of keep_mask is equal to the total number of pixels.
keep_mask = tf.cast(
tf.not_equal(scaled_labels, ignore_label), dtype=tf.float32)
train_labels = None
logits = tf.reshape(logits, shape=[-1, num_classes])
if gt_is_matting_map:
# When the groundtruth is integer label mask, we can assign class
# dependent label weights to the loss. When the groundtruth is image
# matting confidence, we do not apply class-dependent label weight (i.e.,
# label_weight = 1.0).
if loss_weight != 1.0:
raise ValueError(
'loss_weight must equal to 1 if groundtruth is matting map.')
# Assign label value 0 to ignore pixels. The exact label value of ignore
# pixel does not matter, because those ignore_value pixel losses will be
# multiplied to 0 weight.
train_labels = scaled_labels * keep_mask
train_labels = tf.expand_dims(train_labels, 1)
train_labels = tf.concat([1 - train_labels, train_labels], axis=1)
else:
logits = tf.reshape(logits, shape=[-1, num_classes])
weights = not_ignore_mask
with tf.name_scope(loss_scope, 'softmax_hard_example_mining',
[logits, one_hot_labels, weights]):
one_hot_labels = tf.stop_gradient(
one_hot_labels, name='labels_stop_gradient')
pixel_losses = tf.nn.softmax_cross_entropy_with_logits_v2(
labels=one_hot_labels,
logits=logits,
name='pixel_losses')
weighted_pixel_losses = tf.multiply(pixel_losses, weights)
train_labels = tf.one_hot(
scaled_labels, num_classes, on_value=1.0, off_value=0.0)
default_loss_scope = ('softmax_all_pixel_loss'
if top_k_percent_pixels == 1.0 else
'softmax_hard_example_mining')
with tf.name_scope(loss_scope, default_loss_scope,
[logits, train_labels, weights]):
# Compute the loss for all pixels.
pixel_losses = tf.nn.softmax_cross_entropy_with_logits_v2(
labels=tf.stop_gradient(
train_labels, name='train_labels_stop_gradient'),
logits=logits,
name='pixel_losses')
weighted_pixel_losses = tf.multiply(pixel_losses, weights)
if top_k_percent_pixels == 1.0:
total_loss = tf.reduce_sum(weighted_pixel_losses)
num_present = tf.reduce_sum(keep_mask)
loss = _div_maybe_zero(total_loss, num_present)
tf.losses.add_loss(loss)
else:
num_pixels = tf.to_float(tf.shape(logits)[0])
# Compute the top_k_percent pixels based on current training step.
if hard_example_mining_step == 0:
......@@ -160,11 +211,11 @@ def get_model_init_fn(train_logdir,
if not initialize_last_layer:
exclude_list.extend(last_layers)
variables_to_restore = tf.contrib.framework.get_variables_to_restore(
variables_to_restore = contrib_framework.get_variables_to_restore(
exclude=exclude_list)
if variables_to_restore:
init_op, init_feed_dict = tf.contrib.framework.assign_from_checkpoint(
init_op, init_feed_dict = contrib_framework.assign_from_checkpoint(
tf_initial_checkpoint,
variables_to_restore,
ignore_missing_vars=ignore_missing_vars)
......@@ -222,7 +273,11 @@ def get_model_learning_rate(learning_policy,
learning_power,
slow_start_step,
slow_start_learning_rate,
slow_start_burnin_type='none'):
slow_start_burnin_type='none',
decay_steps=0.0,
end_learning_rate=0.0,
boundaries=None,
boundary_learning_rates=None):
"""Gets model's learning rate.
Computes the model's learning rate for different learning policy.
......@@ -249,19 +304,28 @@ def get_model_learning_rate(learning_policy,
`none` which means no burnin or `linear` which means the learning rate
increases linearly from slow_start_learning_rate and reaches
base_learning_rate after slow_start_steps.
decay_steps: Float, `decay_steps` for polynomial learning rate.
end_learning_rate: Float, `end_learning_rate` for polynomial learning rate.
boundaries: A list of `Tensor`s or `int`s or `float`s with strictly
increasing entries.
boundary_learning_rates: A list of `Tensor`s or `float`s or `int`s that
specifies the values for the intervals defined by `boundaries`. It should
have one more element than `boundaries`, and all elements should have the
same type.
Returns:
Learning rate for the specified learning policy.
Raises:
ValueError: If learning policy or slow start burnin type is not recognized.
ValueError: If `boundaries` and `boundary_learning_rates` are not set for
multi_steps learning rate decay.
"""
global_step = tf.train.get_or_create_global_step()
adjusted_global_step = global_step
if slow_start_burnin_type != 'none':
adjusted_global_step -= slow_start_step
adjusted_global_step = tf.maximum(global_step - slow_start_step, 0)
if decay_steps == 0.0:
tf.logging.info('Setting decay_steps to total training steps.')
decay_steps = training_number_of_steps - slow_start_step
if learning_policy == 'step':
learning_rate = tf.train.exponential_decay(
base_learning_rate,
......@@ -273,9 +337,22 @@ def get_model_learning_rate(learning_policy,
learning_rate = tf.train.polynomial_decay(
base_learning_rate,
adjusted_global_step,
training_number_of_steps,
end_learning_rate=0,
decay_steps=decay_steps,
end_learning_rate=end_learning_rate,
power=learning_power)
elif learning_policy == 'cosine':
learning_rate = tf.train.cosine_decay(
base_learning_rate,
adjusted_global_step,
training_number_of_steps - slow_start_step)
elif learning_policy == 'multi_steps':
if boundaries is None or boundary_learning_rates is None:
raise ValueError('Must set `boundaries` and `boundary_learning_rates` '
'for multi_steps learning rate decay.')
learning_rate = tf.train.piecewise_constant_decay(
adjusted_global_step,
boundaries,
boundary_learning_rates)
else:
raise ValueError('Unknown learning policy.')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment