Commit 4fa92552 authored by Dan Kondratyuk, committed by A. Unique TensorFlower
Browse files

Add activation function parameters to backbone.

PiperOrigin-RevId: 381908843
parent f649db2c
......@@ -68,6 +68,23 @@ def round_filters(filters: int,
return int(new_filters)
def hard_swish(x: tf.Tensor) -> tf.Tensor:
  """Applies the hard-swish (H-Swish / Swish6) activation.

  Evaluates `x * relu6(x + 3) * (1 / 6)`.

  Reference: Section 5.2 of Howard et al. "Searching for MobileNet V3."
  https://arxiv.org/pdf/1905.02244.pdf

  Args:
    x: the input tensor.

  Returns:
    The activation output.
  """
  # Keep the multiplication order (x * gate, then the 1/6 scale) so the
  # floating-point result matches the single-expression form exactly.
  gate = tf.nn.relu6(x + 3.)
  return x * gate * (1. / 6.)
# Register `hard_swish` in the Keras custom-object table under a string key so
# that it can be referenced by name, e.g.
# `tf.keras.layers.Activation('hard_swish')`.
tf.keras.utils.get_custom_objects()['hard_swish'] = hard_swish
@tf.keras.utils.register_keras_serializable(package='Vision')
class SqueezeExcitation(tf.keras.layers.Layer):
"""Creates a squeeze and excitation layer."""
......
......@@ -24,6 +24,11 @@ from official.vision.beta.modeling.layers import nn_layers
class NNLayersTest(parameterized.TestCase, tf.test.TestCase):
def test_hard_swish(self):
  """Checks that 'hard_swish' resolves by name and yields known values."""
  # Looking the activation up by string exercises the custom-object
  # registration as well as the function itself.
  activation = tf.keras.layers.Activation('hard_swish')
  output = activation(tf.constant([-3, -1.5, 0, 3]))
  # Expected values follow x * relu6(x + 3) / 6. The outputs are computed in
  # floating point, so compare with a tolerance instead of requiring
  # bit-exact equality, which is brittle across devices and kernels.
  self.assertAllClose(output, [0., -0.375, 0., 3.])
def test_scale(self):
scale = nn_layers.Scale(initializer=tf.keras.initializers.constant(10.))
output = scale(3.)
......
......@@ -44,6 +44,8 @@ class Movinet(hyperparams.Config):
# 2plus1d: (2+1)D convolution with Conv2D (2D reshaping)
# 3d_2plus1d: (2+1)D convolution with Conv3D (no 2D reshaping)
conv_type: str = '3d'
activation: str = 'swish'
gating_activation: str = 'sigmoid'
stochastic_depth_drop_rate: float = 0.2
use_external_states: bool = False
......
......@@ -53,6 +53,12 @@ flags.DEFINE_string(
'3x3 followed by 5x1 conv). 3d_2plus1d uses (2+1)D convolution with '
'Conv3D and no 2D reshaping (e.g., a 5x3x3 kernel becomes 1x3x3 '
'followed by 5x1x1 conv).')
# Activation choices exposed on the command line; both are passed as strings.
flags.DEFINE_string(
    name='activation',
    default='swish',
    help='The main activation to use across layers.')
flags.DEFINE_string(
    name='gating_activation',
    default='sigmoid',
    help='The gating activation to use in squeeze-excitation layers.')
flags.DEFINE_bool(
    name='use_positional_encoding',
    default=False,
    help='Whether to use positional encoding (only applied when '
    'causal=True).')
......@@ -94,6 +100,8 @@ def main(_) -> None:
conv_type=FLAGS.conv_type,
use_external_states=FLAGS.causal,
input_specs=input_specs,
activation=FLAGS.activation,
gating_activation=FLAGS.gating_activation,
use_positional_encoding=FLAGS.use_positional_encoding)
model = movinet_model.MovinetClassifier(
backbone,
......
......@@ -309,6 +309,7 @@ class Movinet(tf.keras.Model):
conv_type: str = '3d',
input_specs: Optional[tf.keras.layers.InputSpec] = None,
activation: str = 'swish',
gating_activation: str = 'sigmoid',
use_sync_bn: bool = True,
norm_momentum: float = 0.99,
norm_epsilon: float = 0.001,
......@@ -333,7 +334,8 @@ class Movinet(tf.keras.Model):
Conv3D and no 2D reshaping (e.g., a 5x3x3 kernel becomes 1x3x3 followed
by 5x1x1 conv).
input_specs: the model input spec to use.
activation: name of the activation function.
activation: name of the main activation function.
gating_activation: gating activation to use in squeeze excitation layers.
use_sync_bn: if True, use synchronized batch normalization.
norm_momentum: normalization momentum for the moving average.
norm_epsilon: small float added to variance to avoid dividing by
......@@ -363,6 +365,7 @@ class Movinet(tf.keras.Model):
self._input_specs = input_specs
self._use_sync_bn = use_sync_bn
self._activation = activation
self._gating_activation = gating_activation
self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon
if use_sync_bn:
......@@ -475,6 +478,7 @@ class Movinet(tf.keras.Model):
strides=strides,
causal=self._causal,
activation=self._activation,
gating_activation=self._gating_activation,
stochastic_depth_drop_rate=stochastic_depth_drop_rate,
conv_type=self._conv_type,
use_positional_encoding=self._use_positional_encoding and
......@@ -692,7 +696,8 @@ def build_movinet(
use_positional_encoding=backbone_cfg.use_positional_encoding,
conv_type=backbone_cfg.conv_type,
input_specs=input_specs,
activation=norm_activation_config.activation,
activation=backbone_cfg.activation,
gating_activation=backbone_cfg.gating_activation,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
......
......@@ -999,6 +999,7 @@ class MovinetBlock(tf.keras.layers.Layer):
strides: Union[int, Sequence[int]] = (1, 1, 1),
causal: bool = False,
activation: nn_layers.Activation = 'swish',
gating_activation: nn_layers.Activation = 'sigmoid',
se_ratio: float = 0.25,
stochastic_depth_drop_rate: float = 0.,
conv_type: str = '3d',
......@@ -1021,6 +1022,7 @@ class MovinetBlock(tf.keras.layers.Layer):
strides: strides of the main depthwise convolution.
causal: if True, run the temporal convolutions in causal mode.
activation: activation to use across all conv operations.
gating_activation: gating activation to use in squeeze excitation layers.
se_ratio: squeeze excite filters ratio.
stochastic_depth_drop_rate: optional drop rate for stochastic depth.
conv_type: '3d', '2plus1d', or '3d_2plus1d'. '3d' uses the default 3D
......@@ -1049,6 +1051,7 @@ class MovinetBlock(tf.keras.layers.Layer):
self._kernel_size = kernel_size
self._causal = causal
self._activation = activation
self._gating_activation = gating_activation
self._se_ratio = se_ratio
self._downsample = any(s > 1 for s in self._strides)
self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
......@@ -1104,6 +1107,7 @@ class MovinetBlock(tf.keras.layers.Layer):
self._attention = StreamSqueezeExcitation(
se_hidden_filters,
activation=activation,
gating_activation=gating_activation,
causal=self._causal,
conv_type=conv_type,
use_positional_encoding=use_positional_encoding,
......@@ -1121,6 +1125,7 @@ class MovinetBlock(tf.keras.layers.Layer):
'strides': self._strides,
'causal': self._causal,
'activation': self._activation,
'gating_activation': self._gating_activation,
'se_ratio': self._se_ratio,
'stochastic_depth_drop_rate': self._stochastic_depth_drop_rate,
'conv_type': self._conv_type,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment