Commit 76476cd9 authored by Jaehong Kim's avatar Jaehong Kim Committed by A. Unique TensorFlower
Browse files

Change activation/add to a keras layer form for the ResNet50, MobilenetV2 default case.

This CL changes ResidualBlock and InvertedBottleneckBlock.

PiperOrigin-RevId: 368383954
parent bf2d1354
......@@ -82,20 +82,36 @@ def is_special_none_tensor(tensor):
return tensor.shape.ndims == 0 and tensor.dtype == tf.int32
def get_activation(identifier):
def get_activation(identifier, use_keras_layer=False):
"""Maps a identifier to a Python function, e.g., "relu" => `tf.nn.relu`.
It checks string first and if it is one of customized activation not in TF,
the corresponding activation will be returned. For non-customized activation
names and callable identifiers, always fallback to tf.keras.activations.get.
Prefers using keras layers when use_keras_layer=True. Currently it only
supports 'relu', 'linear', 'identity', 'swish', and 'relu6'.
Args:
identifier: String name of the activation function or callable.
use_keras_layer: If True, use keras layer if identifier is allow-listed.
Returns:
A Python function corresponding to the activation function.
A Python function corresponding to the activation function or a keras
activation layer when use_keras_layer=True.
"""
if isinstance(identifier, six.string_types):
identifier = str(identifier).lower()
if use_keras_layer:
keras_layer_allowlist = {
"relu": "relu",
"linear": "linear",
"identity": "linear",
"swish": "swish",
"relu6": tf.nn.relu6,
}
if identifier in keras_layer_allowlist:
return tf.keras.layers.Activation(keras_layer_allowlist[identifier])
name_to_fn = {
"gelu": activations.gelu,
"simple_swish": activations.simple_swish,
......@@ -104,7 +120,6 @@ def get_activation(identifier):
"hard_sigmoid": activations.hard_sigmoid,
"identity": activations.identity,
}
identifier = str(identifier).lower()
if identifier in name_to_fn:
return tf.keras.activations.get(name_to_fn[identifier])
return tf.keras.activations.get(identifier)
......
......@@ -31,6 +31,7 @@ layers = tf.keras.layers
# pylint: disable=pointless-string-statement
@tf.keras.utils.register_keras_serializable(package='Vision')
class Conv2DBNBlock(tf.keras.layers.Layer):
"""A convolution block with batch normalization."""
......@@ -94,7 +95,6 @@ class Conv2DBNBlock(tf.keras.layers.Layer):
self._bn_axis = -1
else:
self._bn_axis = 1
self._activation_fn = tf_utils.get_activation(activation)
def get_config(self):
config = {
......@@ -129,6 +129,8 @@ class Conv2DBNBlock(tf.keras.layers.Layer):
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
self._activation_layer = tf_utils.get_activation(
self._activation, use_keras_layer=True)
super(Conv2DBNBlock, self).build(input_shape)
......@@ -136,7 +138,7 @@ class Conv2DBNBlock(tf.keras.layers.Layer):
x = self._conv0(inputs)
if self._use_normalization:
x = self._norm0(x)
return self._activation_fn(x)
return self._activation_layer(x)
"""
Architecture: https://arxiv.org/abs/1704.04861.
......@@ -724,7 +726,7 @@ class MobileNet(tf.keras.Model):
raise ValueError('Unknown block type {} for layer {}'.format(
block_def.block_fn, i))
net = tf.identity(net, name=block_name)
net = tf.keras.layers.Activation('linear', name=block_name)(net)
if block_def.is_output:
endpoints[str(endpoint_level)] = net
......
......@@ -191,7 +191,7 @@ class ResNet(tf.keras.Model):
x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
x)
x = tf_utils.get_activation(activation)(x)
x = tf_utils.get_activation(activation, use_keras_layer=True)(x)
elif stem_type == 'v1':
x = layers.Conv2D(
filters=int(32 * self._depth_multiplier),
......@@ -206,7 +206,7 @@ class ResNet(tf.keras.Model):
x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
x)
x = tf_utils.get_activation(activation)(x)
x = tf_utils.get_activation(activation, use_keras_layer=True)(x)
x = layers.Conv2D(
filters=int(32 * self._depth_multiplier),
kernel_size=3,
......@@ -220,7 +220,7 @@ class ResNet(tf.keras.Model):
x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
x)
x = tf_utils.get_activation(activation)(x)
x = tf_utils.get_activation(activation, use_keras_layer=True)(x)
x = layers.Conv2D(
filters=int(64 * self._depth_multiplier),
kernel_size=3,
......@@ -234,7 +234,7 @@ class ResNet(tf.keras.Model):
x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
x)
x = tf_utils.get_activation(activation)(x)
x = tf_utils.get_activation(activation, use_keras_layer=True)(x)
else:
raise ValueError('Stem type {} not supported.'.format(stem_type))
......@@ -252,7 +252,7 @@ class ResNet(tf.keras.Model):
x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
x)
x = tf_utils.get_activation(activation)(x)
x = tf_utils.get_activation(activation, use_keras_layer=True)(x)
else:
x = layers.MaxPool2D(pool_size=3, strides=2, padding='same')(x)
......@@ -338,7 +338,7 @@ class ResNet(tf.keras.Model):
norm_epsilon=self._norm_epsilon)(
x)
return tf.identity(x, name=name)
return tf.keras.layers.Activation('linear', name=name)(x)
def get_config(self):
config_dict = {
......
......@@ -303,7 +303,6 @@ class BottleneckBlock(tf.keras.layers.Layer):
self._bn_axis = -1
else:
self._bn_axis = 1
self._activation_fn = tf_utils.get_activation(activation)
def build(self, input_shape):
if self._use_projection:
......@@ -345,6 +344,8 @@ class BottleneckBlock(tf.keras.layers.Layer):
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
self._activation1 = tf_utils.get_activation(
self._activation, use_keras_layer=True)
self._conv2 = tf.keras.layers.Conv2D(
filters=self._filters,
......@@ -360,6 +361,8 @@ class BottleneckBlock(tf.keras.layers.Layer):
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
self._activation2 = tf_utils.get_activation(
self._activation, use_keras_layer=True)
self._conv3 = tf.keras.layers.Conv2D(
filters=self._filters * 4,
......@@ -373,6 +376,8 @@ class BottleneckBlock(tf.keras.layers.Layer):
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
self._activation3 = tf_utils.get_activation(
self._activation, use_keras_layer=True)
if self._se_ratio and self._se_ratio > 0 and self._se_ratio <= 1:
self._squeeze_excitation = nn_layers.SqueezeExcitation(
......@@ -390,6 +395,7 @@ class BottleneckBlock(tf.keras.layers.Layer):
self._stochastic_depth_drop_rate)
else:
self._stochastic_depth = None
self._add = tf.keras.layers.Add()
super(BottleneckBlock, self).build(input_shape)
......@@ -425,11 +431,11 @@ class BottleneckBlock(tf.keras.layers.Layer):
x = self._conv1(inputs)
x = self._norm1(x)
x = self._activation_fn(x)
x = self._activation1(x)
x = self._conv2(x)
x = self._norm2(x)
x = self._activation_fn(x)
x = self._activation2(x)
x = self._conv3(x)
x = self._norm3(x)
......@@ -440,7 +446,8 @@ class BottleneckBlock(tf.keras.layers.Layer):
if self._stochastic_depth:
x = self._stochastic_depth(x, training=training)
return self._activation_fn(x + shortcut)
x = self._add([x, shortcut])
return self._activation3(x)
@tf.keras.utils.register_keras_serializable(package='Vision')
......@@ -549,11 +556,8 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer):
self._bn_axis = -1
else:
self._bn_axis = 1
self._activation_fn = tf_utils.get_activation(activation)
if not depthwise_activation:
self._depthwise_activation = activation
self._depthwise_activation_fn = tf_utils.get_activation(
self._depthwise_activation)
if regularize_depthwise:
self._depthsize_regularizer = kernel_regularizer
else:
......@@ -582,6 +586,8 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer):
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
self._activation_layer = tf_utils.get_activation(
self._activation, use_keras_layer=True)
if self._use_depthwise:
# Depthwise conv.
......@@ -599,6 +605,8 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer):
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
self._depthwise_activation_layer = tf_utils.get_activation(
self._depthwise_activation, use_keras_layer=True)
# Squeeze and excitation.
if self._se_ratio and self._se_ratio > 0 and self._se_ratio <= 1:
......@@ -639,6 +647,7 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer):
self._stochastic_depth_drop_rate)
else:
self._stochastic_depth = None
self._add = tf.keras.layers.Add()
super(InvertedBottleneckBlock, self).build(input_shape)
......@@ -676,14 +685,14 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer):
if self._expand_ratio > 1:
x = self._conv0(inputs)
x = self._norm0(x)
x = self._activation_fn(x)
x = self._activation_layer(x)
else:
x = inputs
if self._use_depthwise:
x = self._conv1(x)
x = self._norm1(x)
x = self._depthwise_activation_fn(x)
x = self._depthwise_activation_layer(x)
if self._squeeze_excitation:
x = self._squeeze_excitation(x)
......@@ -696,7 +705,7 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer):
self._strides == 1):
if self._stochastic_depth:
x = self._stochastic_depth(x, training=training)
x = tf.add(x, shortcut)
x = self._add([x, shortcut])
return x
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment