Commit 2c680af3 authored by Sergio Guadarrama, committed by pkulzc

Merged commit includes the following changes:

223150784  by Sergio Guadarrama:

    Allow using batch norm scale parameters for Inception models.

--
221391590  by Sergio Guadarrama:

    Add support for group normalization in the object detection API. Currently only MobileNet-v1 SSD is added; this may serve as a roadmap for other models that wish to support group normalization as an option.

--
221342582  by Sergio Guadarrama:

    Internal change

--
220817084  by Sergio Guadarrama:

    Internal change

--
216005108  by Sergio Guadarrama:

    Introduce hparam `use_bounded_activation` for NASNet. The hparam decides whether to use:
    1. bounded activations,
    2. clip_by_value on the add operands and a bounded activation after the add operator,
    3. a bounded activation before the 'none' and 'pooling' branches.
    The restriction on the tensor value range makes the model compatible with quantized inference.
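
    A minimal usage sketch, mirroring the NASNet CIFAR test added in this change (the input shape and class count are illustrative):

        import tensorflow as tf
        from nets.nasnet import nasnet  # assumed slim research import path

        slim = tf.contrib.slim
        images = tf.random_uniform((1, 32, 32, 3))
        config = nasnet.cifar_config()
        config.set_hparam('use_bounded_activation', True)
        with slim.arg_scope(nasnet.nasnet_cifar_arg_scope()):
          logits, end_points = nasnet.build_nasnet_cifar(
              images, num_classes=10, config=config)
        # With the hparam on, ReLU ops become Relu6, the branch outputs feeding the
        # cell-internal adds are clipped to [-6, 6], and a Relu6 follows each add,
        # keeping tensor values in a quantization-friendly range.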

--

PiperOrigin-RevId: 223150784
parent 5324fc66
......@@ -63,7 +63,8 @@ def cyclegan_arg_scope(instance_norm_center=True,
return sc
def cyclegan_upsample(net, num_outputs, stride, method='conv2d_transpose'):
def cyclegan_upsample(net, num_outputs, stride, method='conv2d_transpose',
pad_mode='REFLECT', align_corners=False):
"""Upsamples the given inputs.
Args:
......@@ -75,6 +76,10 @@ def cyclegan_upsample(net, num_outputs, stride, method='conv2d_transpose'):
times the input size.
method: The upsampling method: 'nn_upsample_conv', 'bilinear_upsample_conv',
or 'conv2d_transpose'.
pad_mode: Mode for tf.pad, one of "CONSTANT", "REFLECT", or "SYMMETRIC".
align_corners: Option for the 'bilinear_upsample_conv' method. If True, the
centers of the 4 corner pixels of the input and output tensors are
aligned, preserving the values at the corner pixels.
Returns:
A Tensor which was upsampled using the specified method.
......@@ -95,12 +100,13 @@ def cyclegan_upsample(net, num_outputs, stride, method='conv2d_transpose'):
if method == 'nn_upsample_conv':
net = tf.image.resize_nearest_neighbor(
net, [stride[0] * height, stride[1] * width])
net = tf.pad(net, spatial_pad_1, 'REFLECT')
net = tf.pad(net, spatial_pad_1, pad_mode)
net = layers.conv2d(net, num_outputs, kernel_size=[3, 3], padding='valid')
elif method == 'bilinear_upsample_conv':
net = tf.image.resize_bilinear(
net, [stride[0] * height, stride[1] * width])
net = tf.pad(net, spatial_pad_1, 'REFLECT')
net, [stride[0] * height, stride[1] * width],
align_corners=align_corners)
net = tf.pad(net, spatial_pad_1, pad_mode)
net = layers.conv2d(net, num_outputs, kernel_size=[3, 3], padding='valid')
elif method == 'conv2d_transpose':
# This corrects 1 pixel offset for images with even width and height.
......@@ -111,7 +117,7 @@ def cyclegan_upsample(net, num_outputs, stride, method='conv2d_transpose'):
net, num_outputs, kernel_size=[3, 3], stride=stride, padding='valid')
net = net[:, 1:, 1:, :]
else:
raise ValueError('Unknown method: [%s]', method)
raise ValueError('Unknown method: [%s]' % method)
return net
......
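A minimal usage sketch for the new pad_mode and align_corners arguments (import path and shapes are illustrative, assuming the slim research layout):

    import tensorflow as tf
    from nets import cyclegan  # assumed import path

    # A batch of four 8x8 feature maps with 32 channels.
    net = tf.random_uniform((4, 8, 8, 32))
    upsampled = cyclegan.cyclegan_upsample(
        net, num_outputs=64, stride=[2, 2],
        method='bilinear_upsample_conv',
        pad_mode='SYMMETRIC',   # forwarded to tf.pad instead of the hard-coded 'REFLECT'
        align_corners=True)     # forwarded to tf.image.resize_bilinear
    # The result is 16x16 spatially with 64 output channels.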
......@@ -370,7 +370,8 @@ def inception_resnet_v2_arg_scope(
batch_norm_decay=0.9997,
batch_norm_epsilon=0.001,
activation_fn=tf.nn.relu,
batch_norm_updates_collections=tf.GraphKeys.UPDATE_OPS):
batch_norm_updates_collections=tf.GraphKeys.UPDATE_OPS,
batch_norm_scale=False):
"""Returns the scope with the default parameters for inception_resnet_v2.
Args:
......@@ -380,6 +381,8 @@ def inception_resnet_v2_arg_scope(
activation_fn: Activation function for conv2d.
batch_norm_updates_collections: Collection for the update ops for
batch norm.
batch_norm_scale: If True, uses an explicit `gamma` multiplier to scale the
activations in the batch normalization layer.
Returns:
an arg_scope with the parameters needed for inception_resnet_v2.
......@@ -394,6 +397,7 @@ def inception_resnet_v2_arg_scope(
'epsilon': batch_norm_epsilon,
'updates_collections': batch_norm_updates_collections,
'fused': None, # Use fused batch norm if possible.
'scale': batch_norm_scale,
}
# Set activation_fn and parameters for batch_norm.
with slim.arg_scope([slim.conv2d], activation_fn=activation_fn,
......
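A minimal usage sketch for the new batch_norm_scale flag, mirroring the tests added below (the placeholder shape is illustrative):

    import tensorflow as tf
    from nets import inception  # assumed import path

    slim = tf.contrib.slim
    inputs = tf.placeholder(tf.float32, (1, 299, 299, 3))
    # Opt in to the explicit gamma multiplier; the default (False) keeps the old behavior.
    with slim.arg_scope(
        inception.inception_resnet_v2_arg_scope(batch_norm_scale=True)):
      logits, end_points = inception.inception_resnet_v2(
          inputs, num_classes=1000, is_training=False)
    # Each batch-norm layer now creates a trainable BatchNorm/gamma variable.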
......@@ -306,6 +306,29 @@ class InceptionTest(tf.test.TestCase):
output = sess.run(predictions)
self.assertEquals(output.shape, (eval_batch_size,))
def testNoBatchNormScaleByDefault(self):
height, width = 299, 299
num_classes = 1000
inputs = tf.placeholder(tf.float32, (1, height, width, 3))
with tf.contrib.slim.arg_scope(inception.inception_resnet_v2_arg_scope()):
inception.inception_resnet_v2(inputs, num_classes, is_training=False)
self.assertEqual(tf.global_variables('.*/BatchNorm/gamma:0$'), [])
def testBatchNormScale(self):
height, width = 299, 299
num_classes = 1000
inputs = tf.placeholder(tf.float32, (1, height, width, 3))
with tf.contrib.slim.arg_scope(
inception.inception_resnet_v2_arg_scope(batch_norm_scale=True)):
inception.inception_resnet_v2(inputs, num_classes, is_training=False)
gamma_names = set(
v.op.name for v in tf.global_variables('.*/BatchNorm/gamma:0$'))
self.assertGreater(len(gamma_names), 0)
for v in tf.global_variables('.*/BatchNorm/moving_mean:0$'):
self.assertIn(v.op.name[:-len('moving_mean')] + 'gamma', gamma_names)
if __name__ == '__main__':
tf.test.main()
......@@ -34,7 +34,8 @@ def inception_arg_scope(weight_decay=0.00004,
batch_norm_decay=0.9997,
batch_norm_epsilon=0.001,
activation_fn=tf.nn.relu,
batch_norm_updates_collections=tf.GraphKeys.UPDATE_OPS):
batch_norm_updates_collections=tf.GraphKeys.UPDATE_OPS,
batch_norm_scale=False):
"""Defines the default arg scope for inception models.
Args:
......@@ -46,6 +47,8 @@ def inception_arg_scope(weight_decay=0.00004,
activation_fn: Activation function for conv2d.
batch_norm_updates_collections: Collection for the update ops for
batch norm.
batch_norm_scale: If True, uses an explicit `gamma` multiplier to scale the
activations in the batch normalization layer.
Returns:
An `arg_scope` to use for the inception models.
......@@ -59,6 +62,7 @@ def inception_arg_scope(weight_decay=0.00004,
'updates_collections': batch_norm_updates_collections,
# use fused batch norm if possible.
'fused': None,
'scale': batch_norm_scale,
}
if use_batch_norm:
normalizer_fn = slim.batch_norm
......
......@@ -237,6 +237,29 @@ class InceptionV1Test(tf.test.TestCase):
logits_out = sess.run(logits)
self.assertListEqual(list(logits_out.shape), [1, 1, 1, num_classes])
def testNoBatchNormScaleByDefault(self):
height, width = 224, 224
num_classes = 1000
inputs = tf.placeholder(tf.float32, (1, height, width, 3))
with slim.arg_scope(inception.inception_v1_arg_scope()):
inception.inception_v1(inputs, num_classes, is_training=False)
self.assertEqual(tf.global_variables('.*/BatchNorm/gamma:0$'), [])
def testBatchNormScale(self):
height, width = 224, 224
num_classes = 1000
inputs = tf.placeholder(tf.float32, (1, height, width, 3))
with slim.arg_scope(
inception.inception_v1_arg_scope(batch_norm_scale=True)):
inception.inception_v1(inputs, num_classes, is_training=False)
gamma_names = set(
v.op.name for v in tf.global_variables('.*/BatchNorm/gamma:0$'))
self.assertGreater(len(gamma_names), 0)
for v in tf.global_variables('.*/BatchNorm/moving_mean:0$'):
self.assertIn(v.op.name[:-len('moving_mean')] + 'gamma', gamma_names)
if __name__ == '__main__':
tf.test.main()
......@@ -351,6 +351,29 @@ class InceptionV2Test(tf.test.TestCase):
logits_out = sess.run(logits)
self.assertListEqual(list(logits_out.shape), [1, 1, 1, num_classes])
def testNoBatchNormScaleByDefault(self):
height, width = 224, 224
num_classes = 1000
inputs = tf.placeholder(tf.float32, (1, height, width, 3))
with slim.arg_scope(inception.inception_v2_arg_scope()):
inception.inception_v2(inputs, num_classes, is_training=False)
self.assertEqual(tf.global_variables('.*/BatchNorm/gamma:0$'), [])
def testBatchNormScale(self):
height, width = 224, 224
num_classes = 1000
inputs = tf.placeholder(tf.float32, (1, height, width, 3))
with slim.arg_scope(
inception.inception_v2_arg_scope(batch_norm_scale=True)):
inception.inception_v2(inputs, num_classes, is_training=False)
gamma_names = set(
v.op.name for v in tf.global_variables('.*/BatchNorm/gamma:0$'))
self.assertGreater(len(gamma_names), 0)
for v in tf.global_variables('.*/BatchNorm/moving_mean:0$'):
self.assertIn(v.op.name[:-len('moving_mean')] + 'gamma', gamma_names)
if __name__ == '__main__':
tf.test.main()
......@@ -318,6 +318,29 @@ class InceptionV3Test(tf.test.TestCase):
logits_out = sess.run(logits)
self.assertListEqual(list(logits_out.shape), [1, 1, 1, num_classes])
def testNoBatchNormScaleByDefault(self):
height, width = 299, 299
num_classes = 1000
inputs = tf.placeholder(tf.float32, (1, height, width, 3))
with slim.arg_scope(inception.inception_v3_arg_scope()):
inception.inception_v3(inputs, num_classes, is_training=False)
self.assertEqual(tf.global_variables('.*/BatchNorm/gamma:0$'), [])
def testBatchNormScale(self):
height, width = 299, 299
num_classes = 1000
inputs = tf.placeholder(tf.float32, (1, height, width, 3))
with slim.arg_scope(
inception.inception_v3_arg_scope(batch_norm_scale=True)):
inception.inception_v3(inputs, num_classes, is_training=False)
gamma_names = set(
v.op.name for v in tf.global_variables('.*/BatchNorm/gamma:0$'))
self.assertGreater(len(gamma_names), 0)
for v in tf.global_variables('.*/BatchNorm/moving_mean:0$'):
self.assertIn(v.op.name[:-len('moving_mean')] + 'gamma', gamma_names)
if __name__ == '__main__':
tf.test.main()
......@@ -255,6 +255,29 @@ class InceptionTest(tf.test.TestCase):
output = sess.run(predictions)
self.assertEquals(output.shape, (eval_batch_size,))
def testNoBatchNormScaleByDefault(self):
height, width = 299, 299
num_classes = 1000
inputs = tf.placeholder(tf.float32, (1, height, width, 3))
with tf.contrib.slim.arg_scope(inception.inception_v4_arg_scope()):
inception.inception_v4(inputs, num_classes, is_training=False)
self.assertEqual(tf.global_variables('.*/BatchNorm/gamma:0$'), [])
def testBatchNormScale(self):
height, width = 299, 299
num_classes = 1000
inputs = tf.placeholder(tf.float32, (1, height, width, 3))
with tf.contrib.slim.arg_scope(
inception.inception_v4_arg_scope(batch_norm_scale=True)):
inception.inception_v4(inputs, num_classes, is_training=False)
gamma_names = set(
v.op.name for v in tf.global_variables('.*/BatchNorm/gamma:0$'))
self.assertGreater(len(gamma_names), 0)
for v in tf.global_variables('.*/BatchNorm/moving_mean:0$'):
self.assertIn(v.op.name[:-len('moving_mean')] + 'gamma', gamma_names)
if __name__ == '__main__':
tf.test.main()
......@@ -263,7 +263,6 @@ def mobilenet_v1_base(inputs,
net = _fixed_padding(net, conv_def.kernel)
net = slim.conv2d(net, depth(conv_def.depth), conv_def.kernel,
stride=conv_def.stride,
normalizer_fn=slim.batch_norm,
scope=end_point)
end_points[end_point] = net
if end_point == final_endpoint:
......@@ -280,7 +279,6 @@ def mobilenet_v1_base(inputs,
depth_multiplier=1,
stride=layer_stride,
rate=layer_rate,
normalizer_fn=slim.batch_norm,
scope=end_point)
end_points[end_point] = net
......@@ -291,7 +289,6 @@ def mobilenet_v1_base(inputs,
net = slim.conv2d(net, depth(conv_def.depth), [1, 1],
stride=1,
normalizer_fn=slim.batch_norm,
scope=end_point)
end_points[end_point] = net
......@@ -432,7 +429,8 @@ def mobilenet_v1_arg_scope(
regularize_depthwise=False,
batch_norm_decay=0.9997,
batch_norm_epsilon=0.001,
batch_norm_updates_collections=tf.GraphKeys.UPDATE_OPS):
batch_norm_updates_collections=tf.GraphKeys.UPDATE_OPS,
normalizer_fn=slim.batch_norm):
"""Defines the default MobilenetV1 arg scope.
Args:
......@@ -446,6 +444,7 @@ def mobilenet_v1_arg_scope(
in batch norm.
batch_norm_updates_collections: Collection for the update ops for
batch norm.
normalizer_fn: Normalization function to apply after convolution.
Returns:
An `arg_scope` to use for the mobilenet v1 model.
......@@ -469,7 +468,7 @@ def mobilenet_v1_arg_scope(
depthwise_regularizer = None
with slim.arg_scope([slim.conv2d, slim.separable_conv2d],
weights_initializer=weights_init,
activation_fn=tf.nn.relu6, normalizer_fn=slim.batch_norm):
activation_fn=tf.nn.relu6, normalizer_fn=normalizer_fn):
with slim.arg_scope([slim.batch_norm], **batch_norm_params):
with slim.arg_scope([slim.conv2d], weights_regularizer=regularizer):
with slim.arg_scope([slim.separable_conv2d],
......
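A minimal sketch of swapping the normalizer through the new normalizer_fn argument, assuming tf.contrib.layers.group_norm as the replacement (the group count and shapes are illustrative; the object detection changes mentioned in the commit message are not part of this file):

    import functools
    import tensorflow as tf
    from nets import mobilenet_v1  # assumed import path

    slim = tf.contrib.slim
    group_norm = functools.partial(tf.contrib.layers.group_norm, groups=8)
    inputs = tf.placeholder(tf.float32, (1, 224, 224, 3))
    with slim.arg_scope(
        mobilenet_v1.mobilenet_v1_arg_scope(normalizer_fn=group_norm)):
      logits, end_points = mobilenet_v1.mobilenet_v1(
          inputs, num_classes=1001, is_training=False)
    # Every conv and depthwise-conv layer now applies group norm instead of
    # slim.batch_norm; the batch-norm hyperparameters in the scope are unused.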
......@@ -52,6 +52,7 @@ def cifar_config():
# This is used for the drop path probabilities since it needs to increase
# the drop out probability over the course of training.
total_training_steps=937500,
use_bounded_activation=False,
)
......@@ -78,6 +79,7 @@ def large_imagenet_config():
data_format='NHWC',
skip_reduction_layer_input=1,
total_training_steps=250000,
use_bounded_activation=False,
)
......@@ -104,6 +106,7 @@ def mobile_imagenet_config():
data_format='NHWC',
skip_reduction_layer_input=0,
total_training_steps=250000,
use_bounded_activation=False,
)
......@@ -223,6 +226,7 @@ def nasnet_large_arg_scope(weight_decay=5e-5,
def _build_aux_head(net, end_points, num_classes, hparams, scope):
"""Auxiliary head used for all models across all datasets."""
activation_fn = tf.nn.relu6 if hparams.use_bounded_activation else tf.nn.relu
with tf.variable_scope(scope):
aux_logits = tf.identity(net)
with tf.variable_scope('aux_logits'):
......@@ -230,7 +234,7 @@ def _build_aux_head(net, end_points, num_classes, hparams, scope):
aux_logits, [5, 5], stride=3, padding='VALID')
aux_logits = slim.conv2d(aux_logits, 128, [1, 1], scope='proj')
aux_logits = slim.batch_norm(aux_logits, scope='aux_bn0')
aux_logits = tf.nn.relu(aux_logits)
aux_logits = activation_fn(aux_logits)
# Shape of feature map before the final layer.
shape = aux_logits.shape
if hparams.data_format == 'NHWC':
......@@ -239,7 +243,7 @@ def _build_aux_head(net, end_points, num_classes, hparams, scope):
shape = shape[2:4]
aux_logits = slim.conv2d(aux_logits, 768, shape, padding='VALID')
aux_logits = slim.batch_norm(aux_logits, scope='aux_bn1')
aux_logits = tf.nn.relu(aux_logits)
aux_logits = activation_fn(aux_logits)
aux_logits = tf.contrib.layers.flatten(aux_logits)
aux_logits = slim.fully_connected(aux_logits, num_classes)
end_points['AuxLogits'] = aux_logits
......@@ -306,10 +310,12 @@ def build_nasnet_cifar(images, num_classes,
normal_cell = nasnet_utils.NasNetANormalCell(
hparams.num_conv_filters, hparams.drop_path_keep_prob,
total_num_cells, hparams.total_training_steps)
total_num_cells, hparams.total_training_steps,
hparams.use_bounded_activation)
reduction_cell = nasnet_utils.NasNetAReductionCell(
hparams.num_conv_filters, hparams.drop_path_keep_prob,
total_num_cells, hparams.total_training_steps)
total_num_cells, hparams.total_training_steps,
hparams.use_bounded_activation)
with arg_scope([slim.dropout, nasnet_utils.drop_path, slim.batch_norm],
is_training=is_training):
with arg_scope([slim.avg_pool2d,
......@@ -358,10 +364,12 @@ def build_nasnet_mobile(images, num_classes,
normal_cell = nasnet_utils.NasNetANormalCell(
hparams.num_conv_filters, hparams.drop_path_keep_prob,
total_num_cells, hparams.total_training_steps)
total_num_cells, hparams.total_training_steps,
hparams.use_bounded_activation)
reduction_cell = nasnet_utils.NasNetAReductionCell(
hparams.num_conv_filters, hparams.drop_path_keep_prob,
total_num_cells, hparams.total_training_steps)
total_num_cells, hparams.total_training_steps,
hparams.use_bounded_activation)
with arg_scope([slim.dropout, nasnet_utils.drop_path, slim.batch_norm],
is_training=is_training):
with arg_scope([slim.avg_pool2d,
......@@ -411,10 +419,12 @@ def build_nasnet_large(images, num_classes,
normal_cell = nasnet_utils.NasNetANormalCell(
hparams.num_conv_filters, hparams.drop_path_keep_prob,
total_num_cells, hparams.total_training_steps)
total_num_cells, hparams.total_training_steps,
hparams.use_bounded_activation)
reduction_cell = nasnet_utils.NasNetAReductionCell(
hparams.num_conv_filters, hparams.drop_path_keep_prob,
total_num_cells, hparams.total_training_steps)
total_num_cells, hparams.total_training_steps,
hparams.use_bounded_activation)
with arg_scope([slim.dropout, nasnet_utils.drop_path, slim.batch_norm],
is_training=is_training):
with arg_scope([slim.avg_pool2d,
......@@ -478,6 +488,7 @@ def _build_nasnet_base(images,
filter_scaling = 1.0
# true_cell_num accounts for the stem cells
true_cell_num = 2 if stem_type == 'imagenet' else 0
activation_fn = tf.nn.relu6 if hparams.use_bounded_activation else tf.nn.relu
for cell_num in range(hparams.num_cells):
stride = 1
if hparams.skip_reduction_layer_input:
......@@ -513,14 +524,14 @@ def _build_nasnet_base(images,
true_cell_num += 1
if (hparams.use_aux_head and cell_num in aux_head_cell_idxes and
num_classes and is_training):
aux_net = tf.nn.relu(net)
aux_net = activation_fn(net)
_build_aux_head(aux_net, end_points, num_classes, hparams,
scope='aux_{}'.format(cell_num))
cell_outputs.append(net)
# Final softmax layer
with tf.variable_scope('final_layer'):
net = tf.nn.relu(net)
net = activation_fn(net)
net = nasnet_utils.global_avg_pool(net)
if add_and_check_endpoint('global_pool', net) or not num_classes:
return net, end_points
......
......@@ -390,5 +390,21 @@ class NASNetTest(tf.test.TestCase):
self.assertListEqual(predictions.get_shape().as_list(),
[batch_size, num_classes])
def testUseBoundedActivationCifarModel(self):
batch_size = 1
height, width = 32, 32
num_classes = 10
for use_bounded_activation in (True, False):
tf.reset_default_graph()
inputs = tf.random_uniform((batch_size, height, width, 3))
config = nasnet.cifar_config()
config.set_hparam('use_bounded_activation', use_bounded_activation)
with slim.arg_scope(nasnet.nasnet_cifar_arg_scope()):
_, _ = nasnet.build_nasnet_cifar(
inputs, num_classes, config=config)
for node in tf.get_default_graph().as_graph_def().node:
if node.op.startswith('Relu'):
self.assertEqual(node.op == 'Relu6', use_bounded_activation)
if __name__ == '__main__':
tf.test.main()
......@@ -40,6 +40,9 @@ slim = tf.contrib.slim
DATA_FORMAT_NCHW = 'NCHW'
DATA_FORMAT_NHWC = 'NHWC'
INVALID = 'null'
# Cap for tf.clip_by_value; the observed activation distribution suggests that
# the majority of activation values fall in the range [-6, 6].
CLIP_BY_VALUE_CAP = 6
def calc_reduction_layers(num_cells, num_reduction_layers):
......@@ -172,11 +175,13 @@ def _operation_to_info(operation):
return num_layers, filter_shape
def _stacked_separable_conv(net, stride, operation, filter_size):
def _stacked_separable_conv(net, stride, operation, filter_size,
use_bounded_activation):
"""Takes in an operations and parses it to the correct sep operation."""
num_layers, kernel_size = _operation_to_info(operation)
activation_fn = tf.nn.relu6 if use_bounded_activation else tf.nn.relu
for layer_num in range(num_layers - 1):
net = tf.nn.relu(net)
net = activation_fn(net)
net = slim.separable_conv2d(
net,
filter_size,
......@@ -187,7 +192,7 @@ def _stacked_separable_conv(net, stride, operation, filter_size):
net = slim.batch_norm(
net, scope='bn_sep_{0}x{0}_{1}'.format(kernel_size, layer_num + 1))
stride = 1
net = tf.nn.relu(net)
net = activation_fn(net)
net = slim.separable_conv2d(
net,
filter_size,
......@@ -223,10 +228,12 @@ def _operation_to_pooling_info(operation):
return pooling_type, pooling_shape
def _pooling(net, stride, operation):
def _pooling(net, stride, operation, use_bounded_activation):
"""Parses operation and performs the correct pooling operation on net."""
padding = 'SAME'
pooling_type, pooling_shape = _operation_to_pooling_info(operation)
if use_bounded_activation:
net = tf.nn.relu6(net)
if pooling_type == 'avg':
net = slim.avg_pool2d(net, pooling_shape, stride=stride, padding=padding)
elif pooling_type == 'max':
......@@ -248,11 +255,13 @@ class NasNetABaseCell(object):
should be concatenated together.
hiddenstate_indices: Determines what hiddenstates should be combined
together with the specified operations to create the NASNet cell.
use_bounded_activation: Whether or not to use bounded activations. Bounded
activations better lend themselves to quantized inference.
"""
def __init__(self, num_conv_filters, operations, used_hiddenstates,
hiddenstate_indices, drop_path_keep_prob, total_num_cells,
total_training_steps):
total_training_steps, use_bounded_activation=False):
self._num_conv_filters = num_conv_filters
self._operations = operations
self._used_hiddenstates = used_hiddenstates
......@@ -260,6 +269,7 @@ class NasNetABaseCell(object):
self._drop_path_keep_prob = drop_path_keep_prob
self._total_num_cells = total_num_cells
self._total_training_steps = total_training_steps
self._use_bounded_activation = use_bounded_activation
def _reduce_prev_layer(self, prev_layer, curr_layer):
"""Matches dimension of prev_layer to the curr_layer."""
......@@ -270,12 +280,13 @@ class NasNetABaseCell(object):
prev_num_filters = get_channel_dim(prev_layer.shape)
curr_filter_shape = int(curr_layer.shape[2])
prev_filter_shape = int(prev_layer.shape[2])
activation_fn = tf.nn.relu6 if self._use_bounded_activation else tf.nn.relu
if curr_filter_shape != prev_filter_shape:
prev_layer = tf.nn.relu(prev_layer)
prev_layer = activation_fn(prev_layer)
prev_layer = factorized_reduction(
prev_layer, curr_num_filters, stride=2)
elif curr_num_filters != prev_num_filters:
prev_layer = tf.nn.relu(prev_layer)
prev_layer = activation_fn(prev_layer)
prev_layer = slim.conv2d(
prev_layer, curr_num_filters, 1, scope='prev_1x1')
prev_layer = slim.batch_norm(prev_layer, scope='prev_bn')
......@@ -288,14 +299,11 @@ class NasNetABaseCell(object):
# Check to be sure prev layer stuff is setup correctly
prev_layer = self._reduce_prev_layer(prev_layer, net)
net = tf.nn.relu(net)
net = tf.nn.relu6(net) if self._use_bounded_activation else tf.nn.relu(net)
net = slim.conv2d(net, num_filters, 1, scope='1x1')
net = slim.batch_norm(net, scope='beginning_bn')
split_axis = get_channel_index()
net = tf.split(axis=split_axis, num_or_size_splits=1, value=net)
for split in net:
assert int(split.shape[split_axis] == int(self._num_conv_filters *
self._filter_scaling))
# num_or_size_splits=1
net = [net]
net.append(prev_layer)
return net
......@@ -335,6 +343,8 @@ class NasNetABaseCell(object):
# Combine hidden states using 'add'.
with tf.variable_scope('combine'):
h = h1 + h2
if self._use_bounded_activation:
h = tf.nn.relu6(h)
# Add hiddenstate to the list of hiddenstates we can choose from
net.append(h)
......@@ -353,18 +363,28 @@ class NasNetABaseCell(object):
input_filters = get_channel_dim(net.shape)
filter_size = self._filter_size
if 'separable' in operation:
net = _stacked_separable_conv(net, stride, operation, filter_size)
net = _stacked_separable_conv(net, stride, operation, filter_size,
self._use_bounded_activation)
if self._use_bounded_activation:
net = tf.clip_by_value(net, -CLIP_BY_VALUE_CAP, CLIP_BY_VALUE_CAP)
elif operation in ['none']:
if self._use_bounded_activation:
net = tf.nn.relu6(net)
# Check if a stride is needed, then use a strided 1x1 here
if stride > 1 or (input_filters != filter_size):
if not self._use_bounded_activation:
net = tf.nn.relu(net)
net = slim.conv2d(net, filter_size, 1, stride=stride, scope='1x1')
net = slim.batch_norm(net, scope='bn_1')
if self._use_bounded_activation:
net = tf.clip_by_value(net, -CLIP_BY_VALUE_CAP, CLIP_BY_VALUE_CAP)
elif 'pool' in operation:
net = _pooling(net, stride, operation)
net = _pooling(net, stride, operation, self._use_bounded_activation)
if input_filters != filter_size:
net = slim.conv2d(net, filter_size, 1, stride=1, scope='1x1')
net = slim.batch_norm(net, scope='bn_1')
if self._use_bounded_activation:
net = tf.clip_by_value(net, -CLIP_BY_VALUE_CAP, CLIP_BY_VALUE_CAP)
else:
raise ValueError('Unimplemented operation', operation)
......@@ -456,7 +476,7 @@ class NasNetANormalCell(NasNetABaseCell):
"""NASNetA Normal Cell."""
def __init__(self, num_conv_filters, drop_path_keep_prob, total_num_cells,
total_training_steps):
total_training_steps, use_bounded_activation=False):
operations = ['separable_5x5_2',
'separable_3x3_2',
'separable_5x5_2',
......@@ -474,14 +494,15 @@ class NasNetANormalCell(NasNetABaseCell):
hiddenstate_indices,
drop_path_keep_prob,
total_num_cells,
total_training_steps)
total_training_steps,
use_bounded_activation)
class NasNetAReductionCell(NasNetABaseCell):
"""NASNetA Reduction Cell."""
def __init__(self, num_conv_filters, drop_path_keep_prob, total_num_cells,
total_training_steps):
total_training_steps, use_bounded_activation=False):
operations = ['separable_5x5_2',
'separable_7x7_2',
'max_pool_3x3',
......@@ -499,4 +520,5 @@ class NasNetAReductionCell(NasNetABaseCell):
hiddenstate_indices,
drop_path_keep_prob,
total_num_cells,
total_training_steps)
total_training_steps,
use_bounded_activation)
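The cell constructors accept the flag directly as well; a minimal construction sketch (the filter and step counts are illustrative placeholders):

    from nets.nasnet import nasnet_utils  # assumed import path

    normal_cell = nasnet_utils.NasNetANormalCell(
        num_conv_filters=32, drop_path_keep_prob=1.0, total_num_cells=6,
        total_training_steps=937500, use_bounded_activation=True)
    reduction_cell = nasnet_utils.NasNetAReductionCell(
        num_conv_filters=32, drop_path_keep_prob=1.0, total_num_cells=6,
        total_training_steps=937500, use_bounded_activation=True)
    # Both cells thread the flag down to NasNetABaseCell, which switches its
    # activations to tf.nn.relu6 and clips branch outputs when the flag is True.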
......@@ -45,6 +45,7 @@ def large_imagenet_config():
data_format='NHWC',
skip_reduction_layer_input=1,
total_training_steps=250000,
use_bounded_activation=False,
)
......@@ -62,6 +63,7 @@ def mobile_imagenet_config():
data_format='NHWC',
skip_reduction_layer_input=1,
total_training_steps=250000,
use_bounded_activation=False,
)
......@@ -114,6 +116,7 @@ def _build_pnasnet_base(images,
filter_scaling = 1.0
# true_cell_num accounts for the stem cells
true_cell_num = 2
activation_fn = tf.nn.relu6 if hparams.use_bounded_activation else tf.nn.relu
for cell_num in range(hparams.num_cells):
is_reduction = cell_num in reduction_indices
stride = 2 if is_reduction else 1
......@@ -134,7 +137,7 @@ def _build_pnasnet_base(images,
if (hparams.use_aux_head and cell_num in aux_head_cell_idxes and
num_classes and is_training):
aux_net = tf.nn.relu(net)
aux_net = activation_fn(net)
# pylint: disable=protected-access
nasnet._build_aux_head(aux_net, end_points, num_classes, hparams,
scope='aux_{}'.format(cell_num))
......@@ -142,7 +145,7 @@ def _build_pnasnet_base(images,
# Final softmax layer
with tf.variable_scope('final_layer'):
net = tf.nn.relu(net)
net = activation_fn(net)
net = nasnet_utils.global_avg_pool(net)
if add_and_check_endpoint('global_pool', net) or not num_classes:
return net, end_points
......@@ -184,7 +187,8 @@ def build_pnasnet_large(images,
normal_cell = PNasNetNormalCell(hparams.num_conv_filters,
hparams.drop_path_keep_prob, total_num_cells,
hparams.total_training_steps)
hparams.total_training_steps,
hparams.use_bounded_activation)
with arg_scope(
[slim.dropout, nasnet_utils.drop_path, slim.batch_norm],
is_training=is_training):
......@@ -231,7 +235,8 @@ def build_pnasnet_mobile(images,
normal_cell = PNasNetNormalCell(hparams.num_conv_filters,
hparams.drop_path_keep_prob, total_num_cells,
hparams.total_training_steps)
hparams.total_training_steps,
hparams.use_bounded_activation)
with arg_scope(
[slim.dropout, nasnet_utils.drop_path, slim.batch_norm],
is_training=is_training):
......@@ -259,7 +264,7 @@ class PNasNetNormalCell(nasnet_utils.NasNetABaseCell):
"""PNASNet Normal Cell."""
def __init__(self, num_conv_filters, drop_path_keep_prob, total_num_cells,
total_training_steps):
total_training_steps, use_bounded_activation=False):
# Configuration for the PNASNet-5 model.
operations = [
'separable_5x5_2', 'max_pool_3x3', 'separable_7x7_2', 'max_pool_3x3',
......@@ -271,4 +276,5 @@ class PNasNetNormalCell(nasnet_utils.NasNetABaseCell):
super(PNasNetNormalCell, self).__init__(
num_conv_filters, operations, used_hiddenstates, hiddenstate_indices,
drop_path_keep_prob, total_num_cells, total_training_steps)
drop_path_keep_prob, total_num_cells, total_training_steps,
use_bounded_activation)
......@@ -236,6 +236,21 @@ class PNASNetTest(tf.test.TestCase):
self.assertListEqual(end_points['Stem'].shape.as_list(),
[batch_size, 135, 28, 28])
def testUseBoundedActivationMobileModel(self):
batch_size = 1
height, width = 224, 224
num_classes = 1000
for use_bounded_activation in (True, False):
tf.reset_default_graph()
inputs = tf.random_uniform((batch_size, height, width, 3))
config = pnasnet.mobile_imagenet_config()
config.set_hparam('use_bounded_activation', use_bounded_activation)
with slim.arg_scope(pnasnet.pnasnet_mobile_arg_scope()):
_, _ = pnasnet.build_pnasnet_mobile(
inputs, num_classes, config=config)
for node in tf.get_default_graph().as_graph_def().node:
if node.op.startswith('Relu'):
self.assertEqual(node.op == 'Relu6', use_bounded_activation)
if __name__ == '__main__':
tf.test.main()
......@@ -90,7 +90,7 @@ def upsample(net, num_outputs, kernel_size, method='nn_upsample_conv'):
net = layers.conv2d_transpose(
net, num_outputs, [4, 4], stride=kernel_size, activation_fn=None)
else:
raise ValueError('Unknown method: [%s]', method)
raise ValueError('Unknown method: [%s]' % method)
return net
......@@ -222,7 +222,8 @@ def pix2pix_generator(net,
return logits, end_points
def pix2pix_discriminator(net, num_filters, padding=2, is_training=False):
def pix2pix_discriminator(net, num_filters, padding=2, pad_mode='REFLECT',
activation_fn=tf.nn.leaky_relu, is_training=False):
"""Creates the Image2Image Translation Discriminator.
Args:
......@@ -231,6 +232,8 @@ def pix2pix_discriminator(net, num_filters, padding=2, is_training=False):
num_filters: A list of the filters in the discriminator. The length of the
list determines the number of layers in the discriminator.
padding: Amount of reflection padding applied before each convolution.
pad_mode: Mode for tf.pad, one of "CONSTANT", "REFLECT", or "SYMMETRIC".
activation_fn: Activation function applied by layers.conv2d.
is_training: Whether or not the model is training or testing.
Returns:
......@@ -249,7 +252,7 @@ def pix2pix_discriminator(net, num_filters, padding=2, is_training=False):
spatial_pad = tf.constant(
[[0, 0], [padding, padding], [padding, padding], [0, 0]],
dtype=tf.int32)
return tf.pad(net, spatial_pad, 'REFLECT')
return tf.pad(net, spatial_pad, pad_mode)
else:
return net
......@@ -258,7 +261,7 @@ def pix2pix_discriminator(net, num_filters, padding=2, is_training=False):
kernel_size=[4, 4],
stride=2,
padding='valid',
activation_fn=tf.nn.leaky_relu):
activation_fn=activation_fn):
# No normalization on the input layer.
net = layers.conv2d(
......
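A minimal usage sketch for the discriminator's new pad_mode and activation_fn arguments (the input shape, filter list, and the ELU swap are illustrative):

    import tensorflow as tf
    from nets import pix2pix  # assumed import path

    images = tf.random_uniform((1, 256, 256, 3))
    disc_output = pix2pix.pix2pix_discriminator(
        images, num_filters=[64, 128, 256, 512],
        padding=2, pad_mode='SYMMETRIC',  # forwarded to tf.pad before each conv
        activation_fn=tf.nn.elu,          # replaces the hard-coded leaky_relu
        is_training=False)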