Commit 2c680af3 authored by Sergio Guadarrama, committed by pkulzc

Merged commit includes the following changes:

223150784  by Sergio Guadarrama:

    Allow using batch norm scale parameters for Inception models.
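
For reference, a minimal usage sketch of the new option (assuming TF 1.x with the slim `nets` package on the path; the gamma check mirrors the new tests further down):

    # Sketch: enable the batch-norm scale (gamma) parameter for Inception-v3.
    import tensorflow as tf
    from nets import inception  # assumed import path from the slim models dir

    slim = tf.contrib.slim
    images = tf.placeholder(tf.float32, (1, 299, 299, 3))
    with slim.arg_scope(inception.inception_v3_arg_scope(batch_norm_scale=True)):
      logits, end_points = inception.inception_v3(
          images, num_classes=1000, is_training=False)
    # With the flag on, every BatchNorm layer gains a learnable gamma variable.
    assert tf.global_variables('.*/BatchNorm/gamma:0$')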

--
221391590  by Sergio Guadarrama:

    Add support for group normalization in the object detection API. Just adding MobileNet-v1 SSD currently. This may serve as a road map for other models that wish to support group normalization as an option.
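
A hedged sketch of how the new hook can be used (MobileNet-v1 is the only model wired up in this change; `tf.contrib.layers.group_norm` and the `nets` import path are assumptions about the surrounding TF 1.x setup):

    # Sketch: swap batch norm for group norm via mobilenet_v1_arg_scope's new
    # normalizer_fn argument (its default remains slim.batch_norm).
    import tensorflow as tf
    from nets import mobilenet_v1

    slim = tf.contrib.slim
    images = tf.placeholder(tf.float32, (1, 224, 224, 3))
    scope = mobilenet_v1.mobilenet_v1_arg_scope(
        normalizer_fn=tf.contrib.layers.group_norm)
    with slim.arg_scope(scope):
      logits, end_points = mobilenet_v1.mobilenet_v1(
          images, num_classes=1001, is_training=False)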

--
221342582  by Sergio Guadarrama:

    Internal change

--
220817084  by Sergio Guadarrama:

    Internal change

--
216005108  by Sergio Guadarrama:

    Introduce the hparam `use_bounded_activation` for NASNet. When enabled, the model uses
    1. bounded activations (tf.nn.relu6 instead of tf.nn.relu),
    2. clip_by_value on the add operands and a bounded activation after the add operator, and
    3. a bounded activation before the 'none' and 'pooling' branches.
    Restricting the tensor value range in this way makes the model compatible with quantized inference.
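
    The following sketch mirrors the new NASNet test and shows the hparam being flipped on (TF 1.x; the `nets.nasnet` import path is assumed):

        # Sketch: build the CIFAR NASNet with bounded activations enabled.
        import tensorflow as tf
        from nets.nasnet import nasnet

        slim = tf.contrib.slim
        inputs = tf.random_uniform((1, 32, 32, 3))
        config = nasnet.cifar_config()
        config.set_hparam('use_bounded_activation', True)
        with slim.arg_scope(nasnet.nasnet_cifar_arg_scope()):
          logits, end_points = nasnet.build_nasnet_cifar(inputs, 10, config=config)
        # ReLU ops become Relu6 and add outputs are clipped to [-6, 6].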

--

PiperOrigin-RevId: 223150784
parent 5324fc66
@@ -63,7 +63,8 @@ def cyclegan_arg_scope(instance_norm_center=True,
   return sc


-def cyclegan_upsample(net, num_outputs, stride, method='conv2d_transpose'):
+def cyclegan_upsample(net, num_outputs, stride, method='conv2d_transpose',
+                      pad_mode='REFLECT', align_corners=False):
   """Upsamples the given inputs.

   Args:
@@ -75,6 +76,10 @@ def cyclegan_upsample(net, num_outputs, stride, method='conv2d_transpose'):
       times the input size.
     method: The upsampling method: 'nn_upsample_conv', 'bilinear_upsample_conv',
       or 'conv2d_transpose'.
+    pad_mode: mode for tf.pad, one of "CONSTANT", "REFLECT", or "SYMMETRIC".
+    align_corners: option for method, 'bilinear_upsample_conv'. If true, the
+      centers of the 4 corner pixels of the input and output tensors are
+      aligned, preserving the values at the corner pixels.

   Returns:
     A Tensor which was upsampled using the specified method.
@@ -95,12 +100,13 @@ def cyclegan_upsample(net, num_outputs, stride, method='conv2d_transpose'):
   if method == 'nn_upsample_conv':
     net = tf.image.resize_nearest_neighbor(
         net, [stride[0] * height, stride[1] * width])
-    net = tf.pad(net, spatial_pad_1, 'REFLECT')
+    net = tf.pad(net, spatial_pad_1, pad_mode)
     net = layers.conv2d(net, num_outputs, kernel_size=[3, 3], padding='valid')
   elif method == 'bilinear_upsample_conv':
     net = tf.image.resize_bilinear(
-        net, [stride[0] * height, stride[1] * width])
-    net = tf.pad(net, spatial_pad_1, 'REFLECT')
+        net, [stride[0] * height, stride[1] * width],
+        align_corners=align_corners)
+    net = tf.pad(net, spatial_pad_1, pad_mode)
     net = layers.conv2d(net, num_outputs, kernel_size=[3, 3], padding='valid')
   elif method == 'conv2d_transpose':
     # This corrects 1 pixel offset for images with even width and height.
@@ -111,7 +117,7 @@ def cyclegan_upsample(net, num_outputs, stride, method='conv2d_transpose'):
         net, num_outputs, kernel_size=[3, 3], stride=stride, padding='valid')
     net = net[:, 1:, 1:, :]
   else:
-    raise ValueError('Unknown method: [%s]', method)
+    raise ValueError('Unknown method: [%s]' % method)

   return net
...
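
A hedged usage sketch for the extended cyclegan_upsample signature above (TF 1.x; the `nets` import path and tensor shape are illustrative assumptions):

    import tensorflow as tf
    from nets import cyclegan

    net = tf.random_uniform((1, 64, 64, 32))
    upsampled = cyclegan.cyclegan_upsample(
        net, num_outputs=16, stride=[2, 2],
        method='bilinear_upsample_conv',
        pad_mode='SYMMETRIC',      # previously hard-coded to 'REFLECT'
        align_corners=True)        # forwarded to tf.image.resize_bilinear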
@@ -370,7 +370,8 @@ def inception_resnet_v2_arg_scope(
     batch_norm_decay=0.9997,
     batch_norm_epsilon=0.001,
     activation_fn=tf.nn.relu,
-    batch_norm_updates_collections=tf.GraphKeys.UPDATE_OPS):
+    batch_norm_updates_collections=tf.GraphKeys.UPDATE_OPS,
+    batch_norm_scale=False):
   """Returns the scope with the default parameters for inception_resnet_v2.

   Args:
@@ -380,6 +381,8 @@ def inception_resnet_v2_arg_scope(
     activation_fn: Activation function for conv2d.
     batch_norm_updates_collections: Collection for the update ops for
       batch norm.
+    batch_norm_scale: If True, uses an explicit `gamma` multiplier to scale the
+      activations in the batch normalization layer.

   Returns:
     a arg_scope with the parameters needed for inception_resnet_v2.
@@ -394,6 +397,7 @@ def inception_resnet_v2_arg_scope(
       'epsilon': batch_norm_epsilon,
       'updates_collections': batch_norm_updates_collections,
       'fused': None,  # Use fused batch norm if possible.
+      'scale': batch_norm_scale,
   }
   # Set activation_fn and parameters for batch_norm.
   with slim.arg_scope([slim.conv2d], activation_fn=activation_fn,
...
@@ -306,6 +306,29 @@ class InceptionTest(tf.test.TestCase):
       output = sess.run(predictions)
       self.assertEquals(output.shape, (eval_batch_size,))

+  def testNoBatchNormScaleByDefault(self):
+    height, width = 299, 299
+    num_classes = 1000
+    inputs = tf.placeholder(tf.float32, (1, height, width, 3))
+    with tf.contrib.slim.arg_scope(inception.inception_resnet_v2_arg_scope()):
+      inception.inception_resnet_v2(inputs, num_classes, is_training=False)
+
+    self.assertEqual(tf.global_variables('.*/BatchNorm/gamma:0$'), [])
+
+  def testBatchNormScale(self):
+    height, width = 299, 299
+    num_classes = 1000
+    inputs = tf.placeholder(tf.float32, (1, height, width, 3))
+    with tf.contrib.slim.arg_scope(
+        inception.inception_resnet_v2_arg_scope(batch_norm_scale=True)):
+      inception.inception_resnet_v2(inputs, num_classes, is_training=False)
+
+    gamma_names = set(
+        v.op.name for v in tf.global_variables('.*/BatchNorm/gamma:0$'))
+    self.assertGreater(len(gamma_names), 0)
+    for v in tf.global_variables('.*/BatchNorm/moving_mean:0$'):
+      self.assertIn(v.op.name[:-len('moving_mean')] + 'gamma', gamma_names)
+

 if __name__ == '__main__':
   tf.test.main()
@@ -34,7 +34,8 @@ def inception_arg_scope(weight_decay=0.00004,
                         batch_norm_decay=0.9997,
                         batch_norm_epsilon=0.001,
                         activation_fn=tf.nn.relu,
-                        batch_norm_updates_collections=tf.GraphKeys.UPDATE_OPS):
+                        batch_norm_updates_collections=tf.GraphKeys.UPDATE_OPS,
+                        batch_norm_scale=False):
   """Defines the default arg scope for inception models.

   Args:
@@ -46,6 +47,8 @@ def inception_arg_scope(weight_decay=0.00004,
     activation_fn: Activation function for conv2d.
     batch_norm_updates_collections: Collection for the update ops for
       batch norm.
+    batch_norm_scale: If True, uses an explicit `gamma` multiplier to scale the
+      activations in the batch normalization layer.

   Returns:
     An `arg_scope` to use for the inception models.
@@ -59,6 +62,7 @@ def inception_arg_scope(weight_decay=0.00004,
       'updates_collections': batch_norm_updates_collections,
       # use fused batch norm if possible.
       'fused': None,
+      'scale': batch_norm_scale,
   }
   if use_batch_norm:
     normalizer_fn = slim.batch_norm
...
@@ -237,6 +237,29 @@ class InceptionV1Test(tf.test.TestCase):
       logits_out = sess.run(logits)
       self.assertListEqual(list(logits_out.shape), [1, 1, 1, num_classes])

+  def testNoBatchNormScaleByDefault(self):
+    height, width = 224, 224
+    num_classes = 1000
+    inputs = tf.placeholder(tf.float32, (1, height, width, 3))
+    with slim.arg_scope(inception.inception_v1_arg_scope()):
+      inception.inception_v1(inputs, num_classes, is_training=False)
+
+    self.assertEqual(tf.global_variables('.*/BatchNorm/gamma:0$'), [])
+
+  def testBatchNormScale(self):
+    height, width = 224, 224
+    num_classes = 1000
+    inputs = tf.placeholder(tf.float32, (1, height, width, 3))
+    with slim.arg_scope(
+        inception.inception_v1_arg_scope(batch_norm_scale=True)):
+      inception.inception_v1(inputs, num_classes, is_training=False)
+
+    gamma_names = set(
+        v.op.name for v in tf.global_variables('.*/BatchNorm/gamma:0$'))
+    self.assertGreater(len(gamma_names), 0)
+    for v in tf.global_variables('.*/BatchNorm/moving_mean:0$'):
+      self.assertIn(v.op.name[:-len('moving_mean')] + 'gamma', gamma_names)
+

 if __name__ == '__main__':
   tf.test.main()
@@ -351,6 +351,29 @@ class InceptionV2Test(tf.test.TestCase):
       logits_out = sess.run(logits)
       self.assertListEqual(list(logits_out.shape), [1, 1, 1, num_classes])

+  def testNoBatchNormScaleByDefault(self):
+    height, width = 224, 224
+    num_classes = 1000
+    inputs = tf.placeholder(tf.float32, (1, height, width, 3))
+    with slim.arg_scope(inception.inception_v2_arg_scope()):
+      inception.inception_v2(inputs, num_classes, is_training=False)
+
+    self.assertEqual(tf.global_variables('.*/BatchNorm/gamma:0$'), [])
+
+  def testBatchNormScale(self):
+    height, width = 224, 224
+    num_classes = 1000
+    inputs = tf.placeholder(tf.float32, (1, height, width, 3))
+    with slim.arg_scope(
+        inception.inception_v2_arg_scope(batch_norm_scale=True)):
+      inception.inception_v2(inputs, num_classes, is_training=False)
+
+    gamma_names = set(
+        v.op.name for v in tf.global_variables('.*/BatchNorm/gamma:0$'))
+    self.assertGreater(len(gamma_names), 0)
+    for v in tf.global_variables('.*/BatchNorm/moving_mean:0$'):
+      self.assertIn(v.op.name[:-len('moving_mean')] + 'gamma', gamma_names)
+

 if __name__ == '__main__':
   tf.test.main()
@@ -318,6 +318,29 @@ class InceptionV3Test(tf.test.TestCase):
       logits_out = sess.run(logits)
       self.assertListEqual(list(logits_out.shape), [1, 1, 1, num_classes])

+  def testNoBatchNormScaleByDefault(self):
+    height, width = 299, 299
+    num_classes = 1000
+    inputs = tf.placeholder(tf.float32, (1, height, width, 3))
+    with slim.arg_scope(inception.inception_v3_arg_scope()):
+      inception.inception_v3(inputs, num_classes, is_training=False)
+
+    self.assertEqual(tf.global_variables('.*/BatchNorm/gamma:0$'), [])
+
+  def testBatchNormScale(self):
+    height, width = 299, 299
+    num_classes = 1000
+    inputs = tf.placeholder(tf.float32, (1, height, width, 3))
+    with slim.arg_scope(
+        inception.inception_v3_arg_scope(batch_norm_scale=True)):
+      inception.inception_v3(inputs, num_classes, is_training=False)
+
+    gamma_names = set(
+        v.op.name for v in tf.global_variables('.*/BatchNorm/gamma:0$'))
+    self.assertGreater(len(gamma_names), 0)
+    for v in tf.global_variables('.*/BatchNorm/moving_mean:0$'):
+      self.assertIn(v.op.name[:-len('moving_mean')] + 'gamma', gamma_names)
+

 if __name__ == '__main__':
   tf.test.main()
@@ -255,6 +255,29 @@ class InceptionTest(tf.test.TestCase):
       output = sess.run(predictions)
       self.assertEquals(output.shape, (eval_batch_size,))

+  def testNoBatchNormScaleByDefault(self):
+    height, width = 299, 299
+    num_classes = 1000
+    inputs = tf.placeholder(tf.float32, (1, height, width, 3))
+    with tf.contrib.slim.arg_scope(inception.inception_v4_arg_scope()):
+      inception.inception_v4(inputs, num_classes, is_training=False)
+
+    self.assertEqual(tf.global_variables('.*/BatchNorm/gamma:0$'), [])
+
+  def testBatchNormScale(self):
+    height, width = 299, 299
+    num_classes = 1000
+    inputs = tf.placeholder(tf.float32, (1, height, width, 3))
+    with tf.contrib.slim.arg_scope(
+        inception.inception_v4_arg_scope(batch_norm_scale=True)):
+      inception.inception_v4(inputs, num_classes, is_training=False)
+
+    gamma_names = set(
+        v.op.name for v in tf.global_variables('.*/BatchNorm/gamma:0$'))
+    self.assertGreater(len(gamma_names), 0)
+    for v in tf.global_variables('.*/BatchNorm/moving_mean:0$'):
+      self.assertIn(v.op.name[:-len('moving_mean')] + 'gamma', gamma_names)
+

 if __name__ == '__main__':
   tf.test.main()
@@ -263,7 +263,6 @@ def mobilenet_v1_base(inputs,
           net = _fixed_padding(net, conv_def.kernel)
         net = slim.conv2d(net, depth(conv_def.depth), conv_def.kernel,
                           stride=conv_def.stride,
-                          normalizer_fn=slim.batch_norm,
                           scope=end_point)
         end_points[end_point] = net
         if end_point == final_endpoint:
@@ -280,7 +279,6 @@ def mobilenet_v1_base(inputs,
                                     depth_multiplier=1,
                                     stride=layer_stride,
                                     rate=layer_rate,
-                                    normalizer_fn=slim.batch_norm,
                                     scope=end_point)
         end_points[end_point] = net
@@ -291,7 +289,6 @@ def mobilenet_v1_base(inputs,
         net = slim.conv2d(net, depth(conv_def.depth), [1, 1],
                           stride=1,
-                          normalizer_fn=slim.batch_norm,
                           scope=end_point)
         end_points[end_point] = net
@@ -432,7 +429,8 @@ def mobilenet_v1_arg_scope(
     regularize_depthwise=False,
     batch_norm_decay=0.9997,
     batch_norm_epsilon=0.001,
-    batch_norm_updates_collections=tf.GraphKeys.UPDATE_OPS):
+    batch_norm_updates_collections=tf.GraphKeys.UPDATE_OPS,
+    normalizer_fn=slim.batch_norm):
   """Defines the default MobilenetV1 arg scope.

   Args:
@@ -446,6 +444,7 @@ def mobilenet_v1_arg_scope(
       in batch norm.
     batch_norm_updates_collections: Collection for the update ops for
       batch norm.
+    normalizer_fn: Normalization function to apply after convolution.

   Returns:
     An `arg_scope` to use for the mobilenet v1 model.
@@ -469,7 +468,7 @@ def mobilenet_v1_arg_scope(
     depthwise_regularizer = None
   with slim.arg_scope([slim.conv2d, slim.separable_conv2d],
                       weights_initializer=weights_init,
-                      activation_fn=tf.nn.relu6, normalizer_fn=slim.batch_norm):
+                      activation_fn=tf.nn.relu6, normalizer_fn=normalizer_fn):
     with slim.arg_scope([slim.batch_norm], **batch_norm_params):
       with slim.arg_scope([slim.conv2d], weights_regularizer=regularizer):
         with slim.arg_scope([slim.separable_conv2d],
...
@@ -52,6 +52,7 @@ def cifar_config():
       # This is used for the drop path probabilities since it needs to increase
       # the drop out probability over the course of training.
       total_training_steps=937500,
+      use_bounded_activation=False,
   )
@@ -78,6 +79,7 @@ def large_imagenet_config():
       data_format='NHWC',
       skip_reduction_layer_input=1,
       total_training_steps=250000,
+      use_bounded_activation=False,
   )
@@ -104,6 +106,7 @@ def mobile_imagenet_config():
       data_format='NHWC',
       skip_reduction_layer_input=0,
       total_training_steps=250000,
+      use_bounded_activation=False,
   )
@@ -223,6 +226,7 @@ def nasnet_large_arg_scope(weight_decay=5e-5,
 def _build_aux_head(net, end_points, num_classes, hparams, scope):
   """Auxiliary head used for all models across all datasets."""
+  activation_fn = tf.nn.relu6 if hparams.use_bounded_activation else tf.nn.relu
   with tf.variable_scope(scope):
     aux_logits = tf.identity(net)
     with tf.variable_scope('aux_logits'):
@@ -230,7 +234,7 @@ def _build_aux_head(net, end_points, num_classes, hparams, scope):
           aux_logits, [5, 5], stride=3, padding='VALID')
       aux_logits = slim.conv2d(aux_logits, 128, [1, 1], scope='proj')
       aux_logits = slim.batch_norm(aux_logits, scope='aux_bn0')
-      aux_logits = tf.nn.relu(aux_logits)
+      aux_logits = activation_fn(aux_logits)
       # Shape of feature map before the final layer.
       shape = aux_logits.shape
       if hparams.data_format == 'NHWC':
@@ -239,7 +243,7 @@ def _build_aux_head(net, end_points, num_classes, hparams, scope):
         shape = shape[2:4]
       aux_logits = slim.conv2d(aux_logits, 768, shape, padding='VALID')
       aux_logits = slim.batch_norm(aux_logits, scope='aux_bn1')
-      aux_logits = tf.nn.relu(aux_logits)
+      aux_logits = activation_fn(aux_logits)
       aux_logits = tf.contrib.layers.flatten(aux_logits)
       aux_logits = slim.fully_connected(aux_logits, num_classes)
       end_points['AuxLogits'] = aux_logits
@@ -306,10 +310,12 @@ def build_nasnet_cifar(images, num_classes,
   normal_cell = nasnet_utils.NasNetANormalCell(
       hparams.num_conv_filters, hparams.drop_path_keep_prob,
-      total_num_cells, hparams.total_training_steps)
+      total_num_cells, hparams.total_training_steps,
+      hparams.use_bounded_activation)
   reduction_cell = nasnet_utils.NasNetAReductionCell(
       hparams.num_conv_filters, hparams.drop_path_keep_prob,
-      total_num_cells, hparams.total_training_steps)
+      total_num_cells, hparams.total_training_steps,
+      hparams.use_bounded_activation)
   with arg_scope([slim.dropout, nasnet_utils.drop_path, slim.batch_norm],
                  is_training=is_training):
     with arg_scope([slim.avg_pool2d,
@@ -358,10 +364,12 @@ def build_nasnet_mobile(images, num_classes,
   normal_cell = nasnet_utils.NasNetANormalCell(
       hparams.num_conv_filters, hparams.drop_path_keep_prob,
-      total_num_cells, hparams.total_training_steps)
+      total_num_cells, hparams.total_training_steps,
+      hparams.use_bounded_activation)
   reduction_cell = nasnet_utils.NasNetAReductionCell(
       hparams.num_conv_filters, hparams.drop_path_keep_prob,
-      total_num_cells, hparams.total_training_steps)
+      total_num_cells, hparams.total_training_steps,
+      hparams.use_bounded_activation)
   with arg_scope([slim.dropout, nasnet_utils.drop_path, slim.batch_norm],
                  is_training=is_training):
     with arg_scope([slim.avg_pool2d,
@@ -411,10 +419,12 @@ def build_nasnet_large(images, num_classes,
   normal_cell = nasnet_utils.NasNetANormalCell(
       hparams.num_conv_filters, hparams.drop_path_keep_prob,
-      total_num_cells, hparams.total_training_steps)
+      total_num_cells, hparams.total_training_steps,
+      hparams.use_bounded_activation)
   reduction_cell = nasnet_utils.NasNetAReductionCell(
       hparams.num_conv_filters, hparams.drop_path_keep_prob,
-      total_num_cells, hparams.total_training_steps)
+      total_num_cells, hparams.total_training_steps,
+      hparams.use_bounded_activation)
   with arg_scope([slim.dropout, nasnet_utils.drop_path, slim.batch_norm],
                  is_training=is_training):
     with arg_scope([slim.avg_pool2d,
@@ -478,6 +488,7 @@ def _build_nasnet_base(images,
   filter_scaling = 1.0
   # true_cell_num accounts for the stem cells
   true_cell_num = 2 if stem_type == 'imagenet' else 0
+  activation_fn = tf.nn.relu6 if hparams.use_bounded_activation else tf.nn.relu
   for cell_num in range(hparams.num_cells):
     stride = 1
     if hparams.skip_reduction_layer_input:
@@ -513,14 +524,14 @@ def _build_nasnet_base(images,
     true_cell_num += 1
     if (hparams.use_aux_head and cell_num in aux_head_cell_idxes and
         num_classes and is_training):
-      aux_net = tf.nn.relu(net)
+      aux_net = activation_fn(net)
       _build_aux_head(aux_net, end_points, num_classes, hparams,
                       scope='aux_{}'.format(cell_num))
     cell_outputs.append(net)

   # Final softmax layer
   with tf.variable_scope('final_layer'):
-    net = tf.nn.relu(net)
+    net = activation_fn(net)
     net = nasnet_utils.global_avg_pool(net)
     if add_and_check_endpoint('global_pool', net) or not num_classes:
       return net, end_points
...
@@ -390,5 +390,21 @@ class NASNetTest(tf.test.TestCase):
       self.assertListEqual(predictions.get_shape().as_list(),
                            [batch_size, num_classes])

+  def testUseBoundedAcitvationCifarModel(self):
+    batch_size = 1
+    height, width = 32, 32
+    num_classes = 10
+    for use_bounded_activation in (True, False):
+      tf.reset_default_graph()
+      inputs = tf.random_uniform((batch_size, height, width, 3))
+      config = nasnet.cifar_config()
+      config.set_hparam('use_bounded_activation', use_bounded_activation)
+      with slim.arg_scope(nasnet.nasnet_cifar_arg_scope()):
+        _, _ = nasnet.build_nasnet_cifar(
+            inputs, num_classes, config=config)
+      for node in tf.get_default_graph().as_graph_def().node:
+        if node.op.startswith('Relu'):
+          self.assertEqual(node.op == 'Relu6', use_bounded_activation)
+

 if __name__ == '__main__':
   tf.test.main()
@@ -40,6 +40,9 @@ slim = tf.contrib.slim
 DATA_FORMAT_NCHW = 'NCHW'
 DATA_FORMAT_NHWC = 'NHWC'
 INVALID = 'null'
+# The cap for tf.clip_by_value, it's hinted from the activation distribution
+# that the majority of activation values are in the range [-6, 6].
+CLIP_BY_VALUE_CAP = 6


 def calc_reduction_layers(num_cells, num_reduction_layers):
@@ -172,11 +175,13 @@ def _operation_to_info(operation):
   return num_layers, filter_shape


-def _stacked_separable_conv(net, stride, operation, filter_size):
+def _stacked_separable_conv(net, stride, operation, filter_size,
+                            use_bounded_activation):
   """Takes in an operations and parses it to the correct sep operation."""
   num_layers, kernel_size = _operation_to_info(operation)
+  activation_fn = tf.nn.relu6 if use_bounded_activation else tf.nn.relu
   for layer_num in range(num_layers - 1):
-    net = tf.nn.relu(net)
+    net = activation_fn(net)
     net = slim.separable_conv2d(
         net,
         filter_size,
@@ -187,7 +192,7 @@ def _stacked_separable_conv(net, stride, operation, filter_size):
     net = slim.batch_norm(
         net, scope='bn_sep_{0}x{0}_{1}'.format(kernel_size, layer_num + 1))
     stride = 1
-  net = tf.nn.relu(net)
+  net = activation_fn(net)
   net = slim.separable_conv2d(
       net,
       filter_size,
@@ -223,10 +228,12 @@ def _operation_to_pooling_info(operation):
   return pooling_type, pooling_shape


-def _pooling(net, stride, operation):
+def _pooling(net, stride, operation, use_bounded_activation):
   """Parses operation and performs the correct pooling operation on net."""
   padding = 'SAME'
   pooling_type, pooling_shape = _operation_to_pooling_info(operation)
+  if use_bounded_activation:
+    net = tf.nn.relu6(net)
   if pooling_type == 'avg':
     net = slim.avg_pool2d(net, pooling_shape, stride=stride, padding=padding)
   elif pooling_type == 'max':
@@ -248,11 +255,13 @@ class NasNetABaseCell(object):
       should be concatenated together.
     hiddenstate_indices: Determines what hiddenstates should be combined
       together with the specified operations to create the NASNet cell.
+    use_bounded_activation: Whether or not to use bounded activations. Bounded
+      activations better lend themselves to quantized inference.
   """

   def __init__(self, num_conv_filters, operations, used_hiddenstates,
                hiddenstate_indices, drop_path_keep_prob, total_num_cells,
-               total_training_steps):
+               total_training_steps, use_bounded_activation=False):
     self._num_conv_filters = num_conv_filters
     self._operations = operations
     self._used_hiddenstates = used_hiddenstates
@@ -260,6 +269,7 @@ class NasNetABaseCell(object):
     self._drop_path_keep_prob = drop_path_keep_prob
     self._total_num_cells = total_num_cells
     self._total_training_steps = total_training_steps
+    self._use_bounded_activation = use_bounded_activation

   def _reduce_prev_layer(self, prev_layer, curr_layer):
     """Matches dimension of prev_layer to the curr_layer."""
@@ -270,12 +280,13 @@ class NasNetABaseCell(object):
       prev_num_filters = get_channel_dim(prev_layer.shape)
       curr_filter_shape = int(curr_layer.shape[2])
       prev_filter_shape = int(prev_layer.shape[2])
+      activation_fn = tf.nn.relu6 if self._use_bounded_activation else tf.nn.relu
       if curr_filter_shape != prev_filter_shape:
-        prev_layer = tf.nn.relu(prev_layer)
+        prev_layer = activation_fn(prev_layer)
         prev_layer = factorized_reduction(
             prev_layer, curr_num_filters, stride=2)
       elif curr_num_filters != prev_num_filters:
-        prev_layer = tf.nn.relu(prev_layer)
+        prev_layer = activation_fn(prev_layer)
         prev_layer = slim.conv2d(
             prev_layer, curr_num_filters, 1, scope='prev_1x1')
         prev_layer = slim.batch_norm(prev_layer, scope='prev_bn')
@@ -288,14 +299,11 @@ class NasNetABaseCell(object):
     # Check to be sure prev layer stuff is setup correctly
     prev_layer = self._reduce_prev_layer(prev_layer, net)

-    net = tf.nn.relu(net)
+    net = tf.nn.relu6(net) if self._use_bounded_activation else tf.nn.relu(net)
     net = slim.conv2d(net, num_filters, 1, scope='1x1')
     net = slim.batch_norm(net, scope='beginning_bn')
-    split_axis = get_channel_index()
-    net = tf.split(axis=split_axis, num_or_size_splits=1, value=net)
-    for split in net:
-      assert int(split.shape[split_axis] == int(self._num_conv_filters *
-                                                 self._filter_scaling))
+    # num_or_size_splits=1
+    net = [net]
     net.append(prev_layer)
     return net
@@ -335,6 +343,8 @@ class NasNetABaseCell(object):
       # Combine hidden states using 'add'.
       with tf.variable_scope('combine'):
         h = h1 + h2
+        if self._use_bounded_activation:
+          h = tf.nn.relu6(h)

       # Add hiddenstate to the list of hiddenstates we can choose from
       net.append(h)
@@ -353,18 +363,28 @@ class NasNetABaseCell(object):
     input_filters = get_channel_dim(net.shape)
     filter_size = self._filter_size
     if 'separable' in operation:
-      net = _stacked_separable_conv(net, stride, operation, filter_size)
+      net = _stacked_separable_conv(net, stride, operation, filter_size,
+                                    self._use_bounded_activation)
+      if self._use_bounded_activation:
+        net = tf.clip_by_value(net, -CLIP_BY_VALUE_CAP, CLIP_BY_VALUE_CAP)
     elif operation in ['none']:
+      if self._use_bounded_activation:
+        net = tf.nn.relu6(net)
       # Check if a stride is needed, then use a strided 1x1 here
       if stride > 1 or (input_filters != filter_size):
-        net = tf.nn.relu(net)
+        if not self._use_bounded_activation:
+          net = tf.nn.relu(net)
         net = slim.conv2d(net, filter_size, 1, stride=stride, scope='1x1')
         net = slim.batch_norm(net, scope='bn_1')
+      if self._use_bounded_activation:
+        net = tf.clip_by_value(net, -CLIP_BY_VALUE_CAP, CLIP_BY_VALUE_CAP)
     elif 'pool' in operation:
-      net = _pooling(net, stride, operation)
+      net = _pooling(net, stride, operation, self._use_bounded_activation)
       if input_filters != filter_size:
         net = slim.conv2d(net, filter_size, 1, stride=1, scope='1x1')
         net = slim.batch_norm(net, scope='bn_1')
+      if self._use_bounded_activation:
+        net = tf.clip_by_value(net, -CLIP_BY_VALUE_CAP, CLIP_BY_VALUE_CAP)
     else:
       raise ValueError('Unimplemented operation', operation)
@@ -456,7 +476,7 @@ class NasNetANormalCell(NasNetABaseCell):
   """NASNetA Normal Cell."""

   def __init__(self, num_conv_filters, drop_path_keep_prob, total_num_cells,
-               total_training_steps):
+               total_training_steps, use_bounded_activation=False):
     operations = ['separable_5x5_2',
                   'separable_3x3_2',
                   'separable_5x5_2',
@@ -474,14 +494,15 @@ class NasNetANormalCell(NasNetABaseCell):
         hiddenstate_indices,
         drop_path_keep_prob,
         total_num_cells,
-        total_training_steps)
+        total_training_steps,
+        use_bounded_activation)


 class NasNetAReductionCell(NasNetABaseCell):
   """NASNetA Reduction Cell."""

   def __init__(self, num_conv_filters, drop_path_keep_prob, total_num_cells,
-               total_training_steps):
+               total_training_steps, use_bounded_activation=False):
     operations = ['separable_5x5_2',
                   'separable_7x7_2',
                   'max_pool_3x3',
@@ -499,4 +520,5 @@ class NasNetAReductionCell(NasNetABaseCell):
         hiddenstate_indices,
         drop_path_keep_prob,
         total_num_cells,
-        total_training_steps)
+        total_training_steps,
+        use_bounded_activation)
@@ -45,6 +45,7 @@ def large_imagenet_config():
       data_format='NHWC',
       skip_reduction_layer_input=1,
       total_training_steps=250000,
+      use_bounded_activation=False,
   )
@@ -62,6 +63,7 @@ def mobile_imagenet_config():
       data_format='NHWC',
       skip_reduction_layer_input=1,
       total_training_steps=250000,
+      use_bounded_activation=False,
   )
@@ -114,6 +116,7 @@ def _build_pnasnet_base(images,
   filter_scaling = 1.0
   # true_cell_num accounts for the stem cells
   true_cell_num = 2
+  activation_fn = tf.nn.relu6 if hparams.use_bounded_activation else tf.nn.relu
   for cell_num in range(hparams.num_cells):
     is_reduction = cell_num in reduction_indices
     stride = 2 if is_reduction else 1
@@ -134,7 +137,7 @@ def _build_pnasnet_base(images,
     if (hparams.use_aux_head and cell_num in aux_head_cell_idxes and
         num_classes and is_training):
-      aux_net = tf.nn.relu(net)
+      aux_net = activation_fn(net)
       # pylint: disable=protected-access
       nasnet._build_aux_head(aux_net, end_points, num_classes, hparams,
                              scope='aux_{}'.format(cell_num))
@@ -142,7 +145,7 @@ def _build_pnasnet_base(images,
   # Final softmax layer
   with tf.variable_scope('final_layer'):
-    net = tf.nn.relu(net)
+    net = activation_fn(net)
     net = nasnet_utils.global_avg_pool(net)
     if add_and_check_endpoint('global_pool', net) or not num_classes:
       return net, end_points
@@ -184,7 +187,8 @@ def build_pnasnet_large(images,
   normal_cell = PNasNetNormalCell(hparams.num_conv_filters,
                                   hparams.drop_path_keep_prob, total_num_cells,
-                                  hparams.total_training_steps)
+                                  hparams.total_training_steps,
+                                  hparams.use_bounded_activation)
   with arg_scope(
       [slim.dropout, nasnet_utils.drop_path, slim.batch_norm],
       is_training=is_training):
@@ -231,7 +235,8 @@ def build_pnasnet_mobile(images,
   normal_cell = PNasNetNormalCell(hparams.num_conv_filters,
                                   hparams.drop_path_keep_prob, total_num_cells,
-                                  hparams.total_training_steps)
+                                  hparams.total_training_steps,
+                                  hparams.use_bounded_activation)
   with arg_scope(
       [slim.dropout, nasnet_utils.drop_path, slim.batch_norm],
       is_training=is_training):
@@ -259,7 +264,7 @@ class PNasNetNormalCell(nasnet_utils.NasNetABaseCell):
   """PNASNet Normal Cell."""

   def __init__(self, num_conv_filters, drop_path_keep_prob, total_num_cells,
-               total_training_steps):
+               total_training_steps, use_bounded_activation=False):
     # Configuration for the PNASNet-5 model.
     operations = [
         'separable_5x5_2', 'max_pool_3x3', 'separable_7x7_2', 'max_pool_3x3',
@@ -271,4 +276,5 @@ class PNasNetNormalCell(nasnet_utils.NasNetABaseCell):
     super(PNasNetNormalCell, self).__init__(
         num_conv_filters, operations, used_hiddenstates, hiddenstate_indices,
-        drop_path_keep_prob, total_num_cells, total_training_steps)
+        drop_path_keep_prob, total_num_cells, total_training_steps,
+        use_bounded_activation)
@@ -236,6 +236,21 @@ class PNASNetTest(tf.test.TestCase):
       self.assertListEqual(end_points['Stem'].shape.as_list(),
                            [batch_size, 135, 28, 28])

+  def testUseBoundedAcitvationMobileModel(self):
+    batch_size = 1
+    height, width = 224, 224
+    num_classes = 1000
+    for use_bounded_activation in (True, False):
+      tf.reset_default_graph()
+      inputs = tf.random_uniform((batch_size, height, width, 3))
+      config = pnasnet.mobile_imagenet_config()
+      config.set_hparam('use_bounded_activation', use_bounded_activation)
+      with slim.arg_scope(pnasnet.pnasnet_mobile_arg_scope()):
+        _, _ = pnasnet.build_pnasnet_mobile(
+            inputs, num_classes, config=config)
+      for node in tf.get_default_graph().as_graph_def().node:
+        if node.op.startswith('Relu'):
+          self.assertEqual(node.op == 'Relu6', use_bounded_activation)
+

 if __name__ == '__main__':
   tf.test.main()
@@ -90,7 +90,7 @@ def upsample(net, num_outputs, kernel_size, method='nn_upsample_conv'):
     net = layers.conv2d_transpose(
         net, num_outputs, [4, 4], stride=kernel_size, activation_fn=None)
   else:
-    raise ValueError('Unknown method: [%s]', method)
+    raise ValueError('Unknown method: [%s]' % method)

   return net
@@ -222,7 +222,8 @@ def pix2pix_generator(net,
   return logits, end_points


-def pix2pix_discriminator(net, num_filters, padding=2, is_training=False):
+def pix2pix_discriminator(net, num_filters, padding=2, pad_mode='REFLECT',
+                          activation_fn=tf.nn.leaky_relu, is_training=False):
   """Creates the Image2Image Translation Discriminator.

   Args:
@@ -231,6 +232,8 @@ def pix2pix_discriminator(net, num_filters, padding=2, is_training=False):
     num_filters: A list of the filters in the discriminator. The length of the
       list determines the number of layers in the discriminator.
     padding: Amount of reflection padding applied before each convolution.
+    pad_mode: mode for tf.pad, one of "CONSTANT", "REFLECT", or "SYMMETRIC".
+    activation_fn: activation fn for layers.conv2d.
     is_training: Whether or not the model is training or testing.

   Returns:
@@ -249,7 +252,7 @@ def pix2pix_discriminator(net, num_filters, padding=2, is_training=False):
       spatial_pad = tf.constant(
           [[0, 0], [padding, padding], [padding, padding], [0, 0]],
           dtype=tf.int32)
-      return tf.pad(net, spatial_pad, 'REFLECT')
+      return tf.pad(net, spatial_pad, pad_mode)
     else:
       return net
@@ -258,7 +261,7 @@ def pix2pix_discriminator(net, num_filters, padding=2, is_training=False):
       kernel_size=[4, 4],
       stride=2,
       padding='valid',
-      activation_fn=tf.nn.leaky_relu):
+      activation_fn=activation_fn):
     # No normalization on the input layer.
     net = layers.conv2d(
...
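
A hedged sketch of calling pix2pix_discriminator with the new arguments (TF 1.x; the `nets` import path, filter counts, and input shape are illustrative assumptions):

    import tensorflow as tf
    from nets import pix2pix

    images = tf.random_uniform((1, 256, 256, 3))
    # pad_mode and activation_fn default to 'REFLECT' and tf.nn.leaky_relu,
    # matching the previous hard-coded behavior.
    disc_out = pix2pix.pix2pix_discriminator(
        images, num_filters=[64, 128, 256, 512],
        pad_mode='SYMMETRIC', activation_fn=tf.nn.relu, is_training=False)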