Commit 37c3e5ff authored by Liangzhe Yuan's avatar Liangzhe Yuan Committed by A. Unique TensorFlower
Browse files

#movinet Support 'none' squeeze and excitation layers in Movinet.

PiperOrigin-RevId: 424743840
parent 55008ba3
...@@ -338,7 +338,7 @@ class Movinet(tf.keras.Model): ...@@ -338,7 +338,7 @@ class Movinet(tf.keras.Model):
3x3 followed by 5x1 conv). '3d_2plus1d' uses (2+1)D convolution with 3x3 followed by 5x1 conv). '3d_2plus1d' uses (2+1)D convolution with
Conv3D and no 2D reshaping (e.g., a 5x3x3 kernel becomes 1x3x3 followed Conv3D and no 2D reshaping (e.g., a 5x3x3 kernel becomes 1x3x3 followed
by 5x1x1 conv). by 5x1x1 conv).
se_type: '3d', '2d', or '2plus3d'. '3d' uses the default 3D se_type: '3d', '2d', '2plus3d' or 'none'. '3d' uses the default 3D
spatiotemporal global average pooling for squeeze excitation. '2d' spatiotemporal global average pooling for squeeze excitation. '2d'
uses 2D spatial global average pooling on each frame. '2plus3d' uses 2D spatial global average pooling on each frame. '2plus3d'
concatenates both 3D and 2D global average pooling. concatenates both 3D and 2D global average pooling.
...@@ -369,7 +369,7 @@ class Movinet(tf.keras.Model): ...@@ -369,7 +369,7 @@ class Movinet(tf.keras.Model):
if conv_type not in ('3d', '2plus1d', '3d_2plus1d'): if conv_type not in ('3d', '2plus1d', '3d_2plus1d'):
raise ValueError('Unknown conv type: {}'.format(conv_type)) raise ValueError('Unknown conv type: {}'.format(conv_type))
if se_type not in ('3d', '2d', '2plus3d'): if se_type not in ('3d', '2d', '2plus3d', 'none'):
raise ValueError('Unknown squeeze excitation type: {}'.format(se_type)) raise ValueError('Unknown squeeze excitation type: {}'.format(se_type))
self._model_id = model_id self._model_id = model_id
...@@ -602,10 +602,11 @@ class Movinet(tf.keras.Model): ...@@ -602,10 +602,11 @@ class Movinet(tf.keras.Model):
expand_filters, expand_filters,
) )
states[f'{prefix}_pool_buffer'] = ( if '3d' in self._se_type:
input_shape[0], 1, 1, 1, expand_filters, states[f'{prefix}_pool_buffer'] = (
) input_shape[0], 1, 1, 1, expand_filters,
states[f'{prefix}_pool_frame_count'] = (1,) )
states[f'{prefix}_pool_frame_count'] = (1,)
if use_positional_encoding: if use_positional_encoding:
name = f'{prefix}_pos_enc_frame_count' name = f'{prefix}_pos_enc_frame_count'
......
...@@ -885,7 +885,8 @@ class MobileBottleneck(tf.keras.layers.Layer): ...@@ -885,7 +885,8 @@ class MobileBottleneck(tf.keras.layers.Layer):
x = self._expansion_layer(inputs) x = self._expansion_layer(inputs)
x, states = self._feature_layer(x, states=states) x, states = self._feature_layer(x, states=states)
x, states = self._attention_layer(x, states=states) if self._attention_layer is not None:
x, states = self._attention_layer(x, states=states)
x = self._projection_layer(x) x = self._projection_layer(x)
# Add identity so that the ops are ordered as written. This is useful for, # Add identity so that the ops are ordered as written. This is useful for,
...@@ -1136,18 +1137,20 @@ class MovinetBlock(tf.keras.layers.Layer): ...@@ -1136,18 +1137,20 @@ class MovinetBlock(tf.keras.layers.Layer):
batch_norm_momentum=self._batch_norm_momentum, batch_norm_momentum=self._batch_norm_momentum,
batch_norm_epsilon=self._batch_norm_epsilon, batch_norm_epsilon=self._batch_norm_epsilon,
name='projection') name='projection')
self._attention = StreamSqueezeExcitation( self._attention = None
se_hidden_filters, if se_type != 'none':
se_type=se_type, self._attention = StreamSqueezeExcitation(
activation=activation, se_hidden_filters,
gating_activation=gating_activation, se_type=se_type,
causal=self._causal, activation=activation,
conv_type=conv_type, gating_activation=gating_activation,
use_positional_encoding=use_positional_encoding, causal=self._causal,
kernel_initializer=kernel_initializer, conv_type=conv_type,
kernel_regularizer=kernel_regularizer, use_positional_encoding=use_positional_encoding,
state_prefix=state_prefix, kernel_initializer=kernel_initializer,
name='se') kernel_regularizer=kernel_regularizer,
state_prefix=state_prefix,
name='se')
def get_config(self): def get_config(self):
"""Returns a dictionary containing the config used for initialization.""" """Returns a dictionary containing the config used for initialization."""
......
...@@ -378,6 +378,35 @@ class MovinetLayersTest(parameterized.TestCase, tf.test.TestCase): ...@@ -378,6 +378,35 @@ class MovinetLayersTest(parameterized.TestCase, tf.test.TestCase):
self.assertEqual(predicted.shape, expected.shape) self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected) self.assertAllClose(predicted, expected)
def test_stream_movinet_block_none_se(self):
block = movinet_layers.MovinetBlock(
out_filters=3,
expand_filters=6,
kernel_size=(3, 3, 3),
strides=(1, 2, 2),
causal=True,
se_type='none',
state_prefix='test',
)
inputs = tf.range(4, dtype=tf.float32) + 1.
inputs = tf.reshape(inputs, [1, 4, 1, 1, 1])
inputs = tf.tile(inputs, [1, 1, 2, 1, 3])
expected, expected_states = block(inputs)
for num_splits in [1, 2, 4]:
frames = tf.split(inputs, inputs.shape[1] // num_splits, axis=1)
states = {}
predicted = []
for frame in frames:
x, states = block(frame, states=states)
predicted.append(x)
predicted = tf.concat(predicted, axis=1)
self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected)
self.assertAllEqual(list(expected_states.keys()), ['test_stream_buffer'])
def test_stream_classifier_head(self): def test_stream_classifier_head(self):
head = movinet_layers.Head(project_filters=5) head = movinet_layers.Head(project_filters=5)
classifier_head = movinet_layers.ClassifierHead( classifier_head = movinet_layers.ClassifierHead(
......
...@@ -99,6 +99,49 @@ class MoViNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -99,6 +99,49 @@ class MoViNetTest(parameterized.TestCase, tf.test.TestCase):
self.assertEqual(predicted.shape, expected.shape) self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected, 1e-5, 1e-5) self.assertAllClose(predicted, expected, 1e-5, 1e-5)
def test_movinet_stream_nse(self):
"""Test if the backbone can be run in streaming mode w/o SE layer."""
tf.keras.backend.set_image_data_format('channels_last')
backbone = movinet.Movinet(
model_id='a0',
causal=True,
use_external_states=True,
se_type='none',
)
inputs = tf.ones([1, 5, 128, 128, 3])
init_states = backbone.init_states(tf.shape(inputs))
expected_endpoints, _ = backbone({**init_states, 'image': inputs})
frames = tf.split(inputs, inputs.shape[1], axis=1)
states = init_states
for frame in frames:
output, states = backbone({**states, 'image': frame})
predicted_endpoints = output
predicted = predicted_endpoints['head']
# The expected final output is simply the mean across frames
expected = expected_endpoints['head']
expected = tf.reduce_mean(expected, 1, keepdims=True)
self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected, 1e-5, 1e-5)
# Check contents in the states dictionary.
state_keys = list(init_states.keys())
self.assertIn('state_head_pool_buffer', state_keys)
self.assertIn('state_head_pool_frame_count', state_keys)
state_keys.remove('state_head_pool_buffer')
state_keys.remove('state_head_pool_frame_count')
# From now on, there are only 'stream_buffer' for the convolutions.
for state_key in state_keys:
self.assertIn(
'stream_buffer', state_key,
msg=f'Expecting stream_buffer only, found {state_key}')
def test_movinet_2plus1d_stream(self): def test_movinet_2plus1d_stream(self):
tf.keras.backend.set_image_data_format('channels_last') tf.keras.backend.set_image_data_format('channels_last')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment