Commit 4fa92552 authored by Dan Kondratyuk's avatar Dan Kondratyuk Committed by A. Unique TensorFlower
Browse files

Add activation function parameters to backbone.

PiperOrigin-RevId: 381908843
parent f649db2c
...@@ -68,6 +68,23 @@ def round_filters(filters: int, ...@@ -68,6 +68,23 @@ def round_filters(filters: int,
return int(new_filters) return int(new_filters)
def hard_swish(x: tf.Tensor) -> tf.Tensor:
"""A Swish6/H-Swish activation function.
Reference: Section 5.2 of Howard et al. "Searching for MobileNet V3."
https://arxiv.org/pdf/1905.02244.pdf
Args:
x: the input tensor.
Returns:
The activation output.
"""
return x * tf.nn.relu6(x + 3.) * (1. / 6.)
tf.keras.utils.get_custom_objects().update({'hard_swish': hard_swish})
@tf.keras.utils.register_keras_serializable(package='Vision') @tf.keras.utils.register_keras_serializable(package='Vision')
class SqueezeExcitation(tf.keras.layers.Layer): class SqueezeExcitation(tf.keras.layers.Layer):
"""Creates a squeeze and excitation layer.""" """Creates a squeeze and excitation layer."""
......
...@@ -24,6 +24,11 @@ from official.vision.beta.modeling.layers import nn_layers ...@@ -24,6 +24,11 @@ from official.vision.beta.modeling.layers import nn_layers
class NNLayersTest(parameterized.TestCase, tf.test.TestCase): class NNLayersTest(parameterized.TestCase, tf.test.TestCase):
def test_hard_swish(self):
activation = tf.keras.layers.Activation('hard_swish')
output = activation(tf.constant([-3, -1.5, 0, 3]))
self.assertAllEqual(output, [0., -0.375, 0., 3.])
def test_scale(self): def test_scale(self):
scale = nn_layers.Scale(initializer=tf.keras.initializers.constant(10.)) scale = nn_layers.Scale(initializer=tf.keras.initializers.constant(10.))
output = scale(3.) output = scale(3.)
......
...@@ -44,6 +44,8 @@ class Movinet(hyperparams.Config): ...@@ -44,6 +44,8 @@ class Movinet(hyperparams.Config):
# 2plus1d: (2+1)D convolution with Conv2D (2D reshaping) # 2plus1d: (2+1)D convolution with Conv2D (2D reshaping)
# 3d_2plus1d: (2+1)D convolution with Conv3D (no 2D reshaping) # 3d_2plus1d: (2+1)D convolution with Conv3D (no 2D reshaping)
conv_type: str = '3d' conv_type: str = '3d'
activation: str = 'swish'
gating_activation: str = 'sigmoid'
stochastic_depth_drop_rate: float = 0.2 stochastic_depth_drop_rate: float = 0.2
use_external_states: bool = False use_external_states: bool = False
......
...@@ -53,6 +53,12 @@ flags.DEFINE_string( ...@@ -53,6 +53,12 @@ flags.DEFINE_string(
'3x3 followed by 5x1 conv). 3d_2plus1d uses (2+1)D convolution with ' '3x3 followed by 5x1 conv). 3d_2plus1d uses (2+1)D convolution with '
'Conv3D and no 2D reshaping (e.g., a 5x3x3 kernel becomes 1x3x3 ' 'Conv3D and no 2D reshaping (e.g., a 5x3x3 kernel becomes 1x3x3 '
'followed by 5x1x1 conv).') 'followed by 5x1x1 conv).')
flags.DEFINE_string(
'activation', 'swish',
'The main activation to use across layers.')
flags.DEFINE_string(
'gating_activation', 'sigmoid',
'The gating activation to use in squeeze-excitation layers.')
flags.DEFINE_bool( flags.DEFINE_bool(
'use_positional_encoding', False, 'use_positional_encoding', False,
'Whether to use positional encoding (only applied when causal=True).') 'Whether to use positional encoding (only applied when causal=True).')
...@@ -94,6 +100,8 @@ def main(_) -> None: ...@@ -94,6 +100,8 @@ def main(_) -> None:
conv_type=FLAGS.conv_type, conv_type=FLAGS.conv_type,
use_external_states=FLAGS.causal, use_external_states=FLAGS.causal,
input_specs=input_specs, input_specs=input_specs,
activation=FLAGS.activation,
gating_activation=FLAGS.gating_activation,
use_positional_encoding=FLAGS.use_positional_encoding) use_positional_encoding=FLAGS.use_positional_encoding)
model = movinet_model.MovinetClassifier( model = movinet_model.MovinetClassifier(
backbone, backbone,
......
...@@ -309,6 +309,7 @@ class Movinet(tf.keras.Model): ...@@ -309,6 +309,7 @@ class Movinet(tf.keras.Model):
conv_type: str = '3d', conv_type: str = '3d',
input_specs: Optional[tf.keras.layers.InputSpec] = None, input_specs: Optional[tf.keras.layers.InputSpec] = None,
activation: str = 'swish', activation: str = 'swish',
gating_activation: str = 'sigmoid',
use_sync_bn: bool = True, use_sync_bn: bool = True,
norm_momentum: float = 0.99, norm_momentum: float = 0.99,
norm_epsilon: float = 0.001, norm_epsilon: float = 0.001,
...@@ -333,7 +334,8 @@ class Movinet(tf.keras.Model): ...@@ -333,7 +334,8 @@ class Movinet(tf.keras.Model):
Conv3D and no 2D reshaping (e.g., a 5x3x3 kernel becomes 1x3x3 followed Conv3D and no 2D reshaping (e.g., a 5x3x3 kernel becomes 1x3x3 followed
by 5x1x1 conv). by 5x1x1 conv).
input_specs: the model input spec to use. input_specs: the model input spec to use.
activation: name of the activation function. activation: name of the main activation function.
gating_activation: gating activation to use in squeeze excitation layers.
use_sync_bn: if True, use synchronized batch normalization. use_sync_bn: if True, use synchronized batch normalization.
norm_momentum: normalization momentum for the moving average. norm_momentum: normalization momentum for the moving average.
norm_epsilon: small float added to variance to avoid dividing by norm_epsilon: small float added to variance to avoid dividing by
...@@ -363,6 +365,7 @@ class Movinet(tf.keras.Model): ...@@ -363,6 +365,7 @@ class Movinet(tf.keras.Model):
self._input_specs = input_specs self._input_specs = input_specs
self._use_sync_bn = use_sync_bn self._use_sync_bn = use_sync_bn
self._activation = activation self._activation = activation
self._gating_activation = gating_activation
self._norm_momentum = norm_momentum self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon self._norm_epsilon = norm_epsilon
if use_sync_bn: if use_sync_bn:
...@@ -475,6 +478,7 @@ class Movinet(tf.keras.Model): ...@@ -475,6 +478,7 @@ class Movinet(tf.keras.Model):
strides=strides, strides=strides,
causal=self._causal, causal=self._causal,
activation=self._activation, activation=self._activation,
gating_activation=self._gating_activation,
stochastic_depth_drop_rate=stochastic_depth_drop_rate, stochastic_depth_drop_rate=stochastic_depth_drop_rate,
conv_type=self._conv_type, conv_type=self._conv_type,
use_positional_encoding=self._use_positional_encoding and use_positional_encoding=self._use_positional_encoding and
...@@ -692,7 +696,8 @@ def build_movinet( ...@@ -692,7 +696,8 @@ def build_movinet(
use_positional_encoding=backbone_cfg.use_positional_encoding, use_positional_encoding=backbone_cfg.use_positional_encoding,
conv_type=backbone_cfg.conv_type, conv_type=backbone_cfg.conv_type,
input_specs=input_specs, input_specs=input_specs,
activation=norm_activation_config.activation, activation=backbone_cfg.activation,
gating_activation=backbone_cfg.gating_activation,
use_sync_bn=norm_activation_config.use_sync_bn, use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum, norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon, norm_epsilon=norm_activation_config.norm_epsilon,
......
...@@ -999,6 +999,7 @@ class MovinetBlock(tf.keras.layers.Layer): ...@@ -999,6 +999,7 @@ class MovinetBlock(tf.keras.layers.Layer):
strides: Union[int, Sequence[int]] = (1, 1, 1), strides: Union[int, Sequence[int]] = (1, 1, 1),
causal: bool = False, causal: bool = False,
activation: nn_layers.Activation = 'swish', activation: nn_layers.Activation = 'swish',
gating_activation: nn_layers.Activation = 'sigmoid',
se_ratio: float = 0.25, se_ratio: float = 0.25,
stochastic_depth_drop_rate: float = 0., stochastic_depth_drop_rate: float = 0.,
conv_type: str = '3d', conv_type: str = '3d',
...@@ -1021,6 +1022,7 @@ class MovinetBlock(tf.keras.layers.Layer): ...@@ -1021,6 +1022,7 @@ class MovinetBlock(tf.keras.layers.Layer):
strides: strides of the main depthwise convolution. strides: strides of the main depthwise convolution.
causal: if True, run the temporal convolutions in causal mode. causal: if True, run the temporal convolutions in causal mode.
activation: activation to use across all conv operations. activation: activation to use across all conv operations.
gating_activation: gating activation to use in squeeze excitation layers.
se_ratio: squeeze excite filters ratio. se_ratio: squeeze excite filters ratio.
stochastic_depth_drop_rate: optional drop rate for stochastic depth. stochastic_depth_drop_rate: optional drop rate for stochastic depth.
conv_type: '3d', '2plus1d', or '3d_2plus1d'. '3d' uses the default 3D conv_type: '3d', '2plus1d', or '3d_2plus1d'. '3d' uses the default 3D
...@@ -1049,6 +1051,7 @@ class MovinetBlock(tf.keras.layers.Layer): ...@@ -1049,6 +1051,7 @@ class MovinetBlock(tf.keras.layers.Layer):
self._kernel_size = kernel_size self._kernel_size = kernel_size
self._causal = causal self._causal = causal
self._activation = activation self._activation = activation
self._gating_activation = gating_activation
self._se_ratio = se_ratio self._se_ratio = se_ratio
self._downsample = any(s > 1 for s in self._strides) self._downsample = any(s > 1 for s in self._strides)
self._stochastic_depth_drop_rate = stochastic_depth_drop_rate self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
...@@ -1104,6 +1107,7 @@ class MovinetBlock(tf.keras.layers.Layer): ...@@ -1104,6 +1107,7 @@ class MovinetBlock(tf.keras.layers.Layer):
self._attention = StreamSqueezeExcitation( self._attention = StreamSqueezeExcitation(
se_hidden_filters, se_hidden_filters,
activation=activation, activation=activation,
gating_activation=gating_activation,
causal=self._causal, causal=self._causal,
conv_type=conv_type, conv_type=conv_type,
use_positional_encoding=use_positional_encoding, use_positional_encoding=use_positional_encoding,
...@@ -1121,6 +1125,7 @@ class MovinetBlock(tf.keras.layers.Layer): ...@@ -1121,6 +1125,7 @@ class MovinetBlock(tf.keras.layers.Layer):
'strides': self._strides, 'strides': self._strides,
'causal': self._causal, 'causal': self._causal,
'activation': self._activation, 'activation': self._activation,
'gating_activation': self._gating_activation,
'se_ratio': self._se_ratio, 'se_ratio': self._se_ratio,
'stochastic_depth_drop_rate': self._stochastic_depth_drop_rate, 'stochastic_depth_drop_rate': self._stochastic_depth_drop_rate,
'conv_type': self._conv_type, 'conv_type': self._conv_type,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment