Commit bc71d8e9 authored by Dan Kondratyuk's avatar Dan Kondratyuk Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 383861528
parent cb9aeaf4
...@@ -44,6 +44,11 @@ class Movinet(hyperparams.Config): ...@@ -44,6 +44,11 @@ class Movinet(hyperparams.Config):
# 2plus1d: (2+1)D convolution with Conv2D (2D reshaping) # 2plus1d: (2+1)D convolution with Conv2D (2D reshaping)
# 3d_2plus1d: (2+1)D convolution with Conv3D (no 2D reshaping) # 3d_2plus1d: (2+1)D convolution with Conv3D (no 2D reshaping)
conv_type: str = '3d' conv_type: str = '3d'
# Choose from ['3d', '2d', '2plus3d']
# 3d: default 3D global average pooling.
# 2d: 2D global average pooling.
# 2plus3d: concatenation of 2D and 3D global average pooling.
se_type: str = '3d'
activation: str = 'swish' activation: str = 'swish'
gating_activation: str = 'sigmoid' gating_activation: str = 'sigmoid'
stochastic_depth_drop_rate: float = 0.2 stochastic_depth_drop_rate: float = 0.2
......
...@@ -53,6 +53,12 @@ flags.DEFINE_string( ...@@ -53,6 +53,12 @@ flags.DEFINE_string(
'3x3 followed by 5x1 conv). 3d_2plus1d uses (2+1)D convolution with ' '3x3 followed by 5x1 conv). 3d_2plus1d uses (2+1)D convolution with '
'Conv3D and no 2D reshaping (e.g., a 5x3x3 kernel becomes 1x3x3 ' 'Conv3D and no 2D reshaping (e.g., a 5x3x3 kernel becomes 1x3x3 '
'followed by 5x1x1 conv).') 'followed by 5x1x1 conv).')
flags.DEFINE_string(
'se_type', '3d',
'3d, 2d, or 2plus3d. 3d uses the default 3D spatiotemporal global average'
'pooling for squeeze excitation. 2d uses 2D spatial global average pooling '
'on each frame. 2plus3d concatenates both 3D and 2D global average '
'pooling.')
flags.DEFINE_string( flags.DEFINE_string(
'activation', 'swish', 'activation', 'swish',
'The main activation to use across layers.') 'The main activation to use across layers.')
...@@ -102,6 +108,7 @@ def main(_) -> None: ...@@ -102,6 +108,7 @@ def main(_) -> None:
input_specs=input_specs, input_specs=input_specs,
activation=FLAGS.activation, activation=FLAGS.activation,
gating_activation=FLAGS.gating_activation, gating_activation=FLAGS.gating_activation,
se_type=FLAGS.se_type,
use_positional_encoding=FLAGS.use_positional_encoding) use_positional_encoding=FLAGS.use_positional_encoding)
model = movinet_model.MovinetClassifier( model = movinet_model.MovinetClassifier(
backbone, backbone,
......
...@@ -307,6 +307,7 @@ class Movinet(tf.keras.Model): ...@@ -307,6 +307,7 @@ class Movinet(tf.keras.Model):
causal: bool = False, causal: bool = False,
use_positional_encoding: bool = False, use_positional_encoding: bool = False,
conv_type: str = '3d', conv_type: str = '3d',
se_type: str = '3d',
input_specs: Optional[tf.keras.layers.InputSpec] = None, input_specs: Optional[tf.keras.layers.InputSpec] = None,
activation: str = 'swish', activation: str = 'swish',
gating_activation: str = 'sigmoid', gating_activation: str = 'sigmoid',
...@@ -333,6 +334,10 @@ class Movinet(tf.keras.Model): ...@@ -333,6 +334,10 @@ class Movinet(tf.keras.Model):
3x3 followed by 5x1 conv). '3d_2plus1d' uses (2+1)D convolution with 3x3 followed by 5x1 conv). '3d_2plus1d' uses (2+1)D convolution with
Conv3D and no 2D reshaping (e.g., a 5x3x3 kernel becomes 1x3x3 followed Conv3D and no 2D reshaping (e.g., a 5x3x3 kernel becomes 1x3x3 followed
by 5x1x1 conv). by 5x1x1 conv).
se_type: '3d', '2d', or '2plus3d'. '3d' uses the default 3D
spatiotemporal global average pooling for squeeze excitation. '2d'
uses 2D spatial global average pooling on each frame. '2plus3d'
concatenates both 3D and 2D global average pooling.
input_specs: the model input spec to use. input_specs: the model input spec to use.
activation: name of the main activation function. activation: name of the main activation function.
gating_activation: gating activation to use in squeeze excitation layers. gating_activation: gating activation to use in squeeze excitation layers.
...@@ -356,12 +361,15 @@ class Movinet(tf.keras.Model): ...@@ -356,12 +361,15 @@ class Movinet(tf.keras.Model):
if conv_type not in ('3d', '2plus1d', '3d_2plus1d'): if conv_type not in ('3d', '2plus1d', '3d_2plus1d'):
raise ValueError('Unknown conv type: {}'.format(conv_type)) raise ValueError('Unknown conv type: {}'.format(conv_type))
if se_type not in ('3d', '2d', '2plus3d'):
raise ValueError('Unknown squeeze excitation type: {}'.format(se_type))
self._model_id = model_id self._model_id = model_id
self._block_specs = block_specs self._block_specs = block_specs
self._causal = causal self._causal = causal
self._use_positional_encoding = use_positional_encoding self._use_positional_encoding = use_positional_encoding
self._conv_type = conv_type self._conv_type = conv_type
self._se_type = se_type
self._input_specs = input_specs self._input_specs = input_specs
self._use_sync_bn = use_sync_bn self._use_sync_bn = use_sync_bn
self._activation = activation self._activation = activation
...@@ -481,8 +489,9 @@ class Movinet(tf.keras.Model): ...@@ -481,8 +489,9 @@ class Movinet(tf.keras.Model):
gating_activation=self._gating_activation, gating_activation=self._gating_activation,
stochastic_depth_drop_rate=stochastic_depth_drop_rate, stochastic_depth_drop_rate=stochastic_depth_drop_rate,
conv_type=self._conv_type, conv_type=self._conv_type,
use_positional_encoding=self._use_positional_encoding and se_type=self._se_type,
self._causal, use_positional_encoding=
self._use_positional_encoding and self._causal,
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer, kernel_regularizer=self._kernel_regularizer,
batch_norm_layer=self._norm, batch_norm_layer=self._norm,
...@@ -695,6 +704,7 @@ def build_movinet( ...@@ -695,6 +704,7 @@ def build_movinet(
causal=backbone_cfg.causal, causal=backbone_cfg.causal,
use_positional_encoding=backbone_cfg.use_positional_encoding, use_positional_encoding=backbone_cfg.use_positional_encoding,
conv_type=backbone_cfg.conv_type, conv_type=backbone_cfg.conv_type,
se_type=backbone_cfg.se_type,
input_specs=input_specs, input_specs=input_specs,
activation=backbone_cfg.activation, activation=backbone_cfg.activation,
gating_activation=backbone_cfg.gating_activation, gating_activation=backbone_cfg.gating_activation,
......
...@@ -669,6 +669,7 @@ class StreamSqueezeExcitation(tf.keras.layers.Layer): ...@@ -669,6 +669,7 @@ class StreamSqueezeExcitation(tf.keras.layers.Layer):
def __init__( def __init__(
self, self,
hidden_filters: int, hidden_filters: int,
se_type: str = '3d',
activation: nn_layers.Activation = 'swish', activation: nn_layers.Activation = 'swish',
gating_activation: nn_layers.Activation = 'sigmoid', gating_activation: nn_layers.Activation = 'sigmoid',
causal: bool = False, causal: bool = False,
...@@ -683,6 +684,10 @@ class StreamSqueezeExcitation(tf.keras.layers.Layer): ...@@ -683,6 +684,10 @@ class StreamSqueezeExcitation(tf.keras.layers.Layer):
Args: Args:
hidden_filters: The hidden filters of squeeze excite. hidden_filters: The hidden filters of squeeze excite.
se_type: '3d', '2d', or '2plus3d'. '3d' uses the default 3D
spatiotemporal global average pooling for squeeze excitation. '2d'
uses 2D spatial global average pooling on each frame. '2plus3d'
concatenates both 3D and 2D global average pooling.
activation: name of the activation function. activation: name of the activation function.
gating_activation: name of the activation function for gating. gating_activation: name of the activation function for gating.
causal: if True, use causal mode in the global average pool. causal: if True, use causal mode in the global average pool.
...@@ -700,6 +705,7 @@ class StreamSqueezeExcitation(tf.keras.layers.Layer): ...@@ -700,6 +705,7 @@ class StreamSqueezeExcitation(tf.keras.layers.Layer):
super(StreamSqueezeExcitation, self).__init__(**kwargs) super(StreamSqueezeExcitation, self).__init__(**kwargs)
self._hidden_filters = hidden_filters self._hidden_filters = hidden_filters
self._se_type = se_type
self._activation = activation self._activation = activation
self._gating_activation = gating_activation self._gating_activation = gating_activation
self._causal = causal self._causal = causal
...@@ -709,8 +715,9 @@ class StreamSqueezeExcitation(tf.keras.layers.Layer): ...@@ -709,8 +715,9 @@ class StreamSqueezeExcitation(tf.keras.layers.Layer):
self._use_positional_encoding = use_positional_encoding self._use_positional_encoding = use_positional_encoding
self._state_prefix = state_prefix self._state_prefix = state_prefix
self._pool = nn_layers.GlobalAveragePool3D( self._spatiotemporal_pool = nn_layers.GlobalAveragePool3D(
keepdims=True, causal=causal, state_prefix=state_prefix) keepdims=True, causal=causal, state_prefix=state_prefix)
self._spatial_pool = nn_layers.SpatialAveragePool3D(keepdims=True)
self._pos_encoding = None self._pos_encoding = None
if use_positional_encoding: if use_positional_encoding:
...@@ -721,6 +728,7 @@ class StreamSqueezeExcitation(tf.keras.layers.Layer): ...@@ -721,6 +728,7 @@ class StreamSqueezeExcitation(tf.keras.layers.Layer):
"""Returns a dictionary containing the config used for initialization.""" """Returns a dictionary containing the config used for initialization."""
config = { config = {
'hidden_filters': self._hidden_filters, 'hidden_filters': self._hidden_filters,
'se_type': self._se_type,
'activation': self._activation, 'activation': self._activation,
'gating_activation': self._gating_activation, 'gating_activation': self._gating_activation,
'causal': self._causal, 'causal': self._causal,
...@@ -777,13 +785,28 @@ class StreamSqueezeExcitation(tf.keras.layers.Layer): ...@@ -777,13 +785,28 @@ class StreamSqueezeExcitation(tf.keras.layers.Layer):
""" """
states = dict(states) if states is not None else {} states = dict(states) if states is not None else {}
x, states = self._pool(inputs, states=states) if self._se_type == '3d':
x, states = self._spatiotemporal_pool(inputs, states=states)
elif self._se_type == '2d':
x = self._spatial_pool(inputs)
elif self._se_type == '2plus3d':
x_space = self._spatial_pool(inputs)
x, states = self._spatiotemporal_pool(x_space, states=states)
if not self._causal:
x = tf.tile(x, [1, tf.shape(inputs)[1], 1, 1, 1])
x = tf.concat([x, x_space], axis=-1)
else:
raise ValueError('Unknown Squeeze Excitation type {}'.format(
self._se_type))
if self._pos_encoding is not None: if self._pos_encoding is not None:
x, states = self._pos_encoding(x, states=states) x, states = self._pos_encoding(x, states=states)
x = self._se_reduce(x) x = self._se_reduce(x)
x = self._se_expand(x) x = self._se_expand(x)
return x * inputs, states return x * inputs, states
...@@ -1003,6 +1026,7 @@ class MovinetBlock(tf.keras.layers.Layer): ...@@ -1003,6 +1026,7 @@ class MovinetBlock(tf.keras.layers.Layer):
se_ratio: float = 0.25, se_ratio: float = 0.25,
stochastic_depth_drop_rate: float = 0., stochastic_depth_drop_rate: float = 0.,
conv_type: str = '3d', conv_type: str = '3d',
se_type: str = '3d',
use_positional_encoding: bool = False, use_positional_encoding: bool = False,
kernel_initializer: tf.keras.initializers.Initializer = 'HeNormal', kernel_initializer: tf.keras.initializers.Initializer = 'HeNormal',
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = tf.keras kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = tf.keras
...@@ -1029,6 +1053,10 @@ class MovinetBlock(tf.keras.layers.Layer): ...@@ -1029,6 +1053,10 @@ class MovinetBlock(tf.keras.layers.Layer):
ops. '2plus1d' split any 3D ops into two sequential 2D ops with their ops. '2plus1d' split any 3D ops into two sequential 2D ops with their
own batch norm and activation. '3d_2plus1d' is like '2plus1d', but own batch norm and activation. '3d_2plus1d' is like '2plus1d', but
uses two sequential 3D ops instead. uses two sequential 3D ops instead.
se_type: '3d', '2d', or '2plus3d'. '3d' uses the default 3D
spatiotemporal global average pooling for squeeze excitation. '2d'
uses 2D spatial global average pooling on each frame. '2plus3d'
concatenates both 3D and 2D global average pooling.
use_positional_encoding: add a positional encoding after the (cumulative) use_positional_encoding: add a positional encoding after the (cumulative)
global average pooling layer in the squeeze excite layer. global average pooling layer in the squeeze excite layer.
kernel_initializer: kernel initializer for the conv operations. kernel_initializer: kernel initializer for the conv operations.
...@@ -1044,8 +1072,10 @@ class MovinetBlock(tf.keras.layers.Layer): ...@@ -1044,8 +1072,10 @@ class MovinetBlock(tf.keras.layers.Layer):
self._kernel_size = normalize_tuple(kernel_size, 3, 'kernel_size') self._kernel_size = normalize_tuple(kernel_size, 3, 'kernel_size')
self._strides = normalize_tuple(strides, 3, 'strides') self._strides = normalize_tuple(strides, 3, 'strides')
# Use a multiplier of 2 if concatenating multiple features
se_multiplier = 2 if se_type == '2plus3d' else 1
se_hidden_filters = nn_layers.make_divisible( se_hidden_filters = nn_layers.make_divisible(
se_ratio * expand_filters, divisor=8) se_ratio * expand_filters * se_multiplier, divisor=8)
self._out_filters = out_filters self._out_filters = out_filters
self._expand_filters = expand_filters self._expand_filters = expand_filters
self._kernel_size = kernel_size self._kernel_size = kernel_size
...@@ -1056,6 +1086,7 @@ class MovinetBlock(tf.keras.layers.Layer): ...@@ -1056,6 +1086,7 @@ class MovinetBlock(tf.keras.layers.Layer):
self._downsample = any(s > 1 for s in self._strides) self._downsample = any(s > 1 for s in self._strides)
self._stochastic_depth_drop_rate = stochastic_depth_drop_rate self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
self._conv_type = conv_type self._conv_type = conv_type
self._se_type = se_type
self._use_positional_encoding = use_positional_encoding self._use_positional_encoding = use_positional_encoding
self._kernel_initializer = kernel_initializer self._kernel_initializer = kernel_initializer
self._kernel_regularizer = kernel_regularizer self._kernel_regularizer = kernel_regularizer
...@@ -1106,6 +1137,7 @@ class MovinetBlock(tf.keras.layers.Layer): ...@@ -1106,6 +1137,7 @@ class MovinetBlock(tf.keras.layers.Layer):
name='projection') name='projection')
self._attention = StreamSqueezeExcitation( self._attention = StreamSqueezeExcitation(
se_hidden_filters, se_hidden_filters,
se_type=se_type,
activation=activation, activation=activation,
gating_activation=gating_activation, gating_activation=gating_activation,
causal=self._causal, causal=self._causal,
...@@ -1129,6 +1161,7 @@ class MovinetBlock(tf.keras.layers.Layer): ...@@ -1129,6 +1161,7 @@ class MovinetBlock(tf.keras.layers.Layer):
'se_ratio': self._se_ratio, 'se_ratio': self._se_ratio,
'stochastic_depth_drop_rate': self._stochastic_depth_drop_rate, 'stochastic_depth_drop_rate': self._stochastic_depth_drop_rate,
'conv_type': self._conv_type, 'conv_type': self._conv_type,
'se_type': self._se_type,
'use_positional_encoding': self._use_positional_encoding, 'use_positional_encoding': self._use_positional_encoding,
'kernel_initializer': self._kernel_initializer, 'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer, 'kernel_regularizer': self._kernel_regularizer,
......
...@@ -314,6 +314,43 @@ class MovinetLayersTest(parameterized.TestCase, tf.test.TestCase): ...@@ -314,6 +314,43 @@ class MovinetLayersTest(parameterized.TestCase, tf.test.TestCase):
[[4., 4., 4.]]]]], [[4., 4., 4.]]]]],
1e-5, 1e-5) 1e-5, 1e-5)
def test_stream_squeeze_excitation_2plus3d(self):
se = movinet_layers.StreamSqueezeExcitation(
3,
se_type='2plus3d',
causal=True,
activation='hard_swish',
gating_activation='hard_sigmoid',
kernel_initializer='ones')
inputs = tf.range(4, dtype=tf.float32) + 1.
inputs = tf.reshape(inputs, [1, 4, 1, 1, 1])
inputs = tf.tile(inputs, [1, 1, 2, 1, 3])
expected, _ = se(inputs)
for num_splits in [1, 2, 4]:
frames = tf.split(inputs, inputs.shape[1] // num_splits, axis=1)
states = {}
predicted = []
for frame in frames:
x, states = se(frame, states=states)
predicted.append(x)
predicted = tf.concat(predicted, axis=1)
self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected)
self.assertAllClose(
predicted,
[[[[[1., 1., 1.]],
[[1., 1., 1.]]],
[[[2., 2., 2.]],
[[2., 2., 2.]]],
[[[3., 3., 3.]],
[[3., 3., 3.]]],
[[[4., 4., 4.]],
[[4., 4., 4.]]]]])
def test_stream_movinet_block(self): def test_stream_movinet_block(self):
block = movinet_layers.MovinetBlock( block = movinet_layers.MovinetBlock(
out_filters=3, out_filters=3,
......
...@@ -131,6 +131,37 @@ class MovinetModelTest(parameterized.TestCase, tf.test.TestCase): ...@@ -131,6 +131,37 @@ class MovinetModelTest(parameterized.TestCase, tf.test.TestCase):
self.assertEqual(predicted.shape, expected.shape) self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected, 1e-5, 1e-5) self.assertAllClose(predicted, expected, 1e-5, 1e-5)
def test_movinet_classifier_mobile(self):
"""Test if the model can run with mobile parameters."""
tf.keras.backend.set_image_data_format('channels_last')
backbone = movinet.Movinet(
model_id='a0',
causal=True,
use_external_states=True,
conv_type='2plus1d',
se_type='2plus3d',
activation='hard_swish',
gating_activation='hard_sigmoid'
)
model = movinet_model.MovinetClassifier(
backbone, num_classes=600, output_states=True)
inputs = tf.ones([1, 8, 172, 172, 3])
init_states = model.init_states(tf.shape(inputs))
expected, _ = model({**init_states, 'image': inputs})
frames = tf.split(inputs, inputs.shape[1], axis=1)
states = init_states
for frame in frames:
output, states = model({**states, 'image': frame})
predicted = output
self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected, 1e-5, 1e-5)
def test_serialize_deserialize(self): def test_serialize_deserialize(self):
"""Validate the classification network can be serialized and deserialized.""" """Validate the classification network can be serialized and deserialized."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment