Commit 97b97572 authored by Liangzhe Yuan, committed by A. Unique TensorFlower
Browse files

#movinet #vcc Add option global_average_pooling to Movinet. Allow model to...

#movinet #vcc Add option global_average_pooling to Movinet. Allow model to output time-dim preserved features.

PiperOrigin-RevId: 445205488
parent d2f47b86
...@@ -53,6 +53,7 @@ class Movinet(hyperparams.Config): ...@@ -53,6 +53,7 @@ class Movinet(hyperparams.Config):
gating_activation: str = 'sigmoid' gating_activation: str = 'sigmoid'
stochastic_depth_drop_rate: float = 0.2 stochastic_depth_drop_rate: float = 0.2
use_external_states: bool = False use_external_states: bool = False
average_pooling_type: str = '3d'
@dataclasses.dataclass @dataclasses.dataclass
......
...@@ -322,6 +322,7 @@ class Movinet(tf.keras.Model): ...@@ -322,6 +322,7 @@ class Movinet(tf.keras.Model):
stochastic_depth_drop_rate: float = 0., stochastic_depth_drop_rate: float = 0.,
use_external_states: bool = False, use_external_states: bool = False,
output_states: bool = True, output_states: bool = True,
average_pooling_type: str = '3d',
**kwargs): **kwargs):
"""MoViNet initialization function. """MoViNet initialization function.
...@@ -360,6 +361,8 @@ class Movinet(tf.keras.Model): ...@@ -360,6 +361,8 @@ class Movinet(tf.keras.Model):
the model in streaming mode. Inputting the output states of the the model in streaming mode. Inputting the output states of the
previous input clip with the current input clip will utilize a stream previous input clip with the current input clip will utilize a stream
buffer for streaming video. buffer for streaming video.
average_pooling_type: The average pooling type. Currently supporting
['3d', '2d', 'none'].
**kwargs: keyword arguments to be passed. **kwargs: keyword arguments to be passed.
""" """
block_specs = BLOCK_SPECS[model_id] block_specs = BLOCK_SPECS[model_id]
...@@ -393,6 +396,7 @@ class Movinet(tf.keras.Model): ...@@ -393,6 +396,7 @@ class Movinet(tf.keras.Model):
self._stochastic_depth_drop_rate = stochastic_depth_drop_rate self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
self._use_external_states = use_external_states self._use_external_states = use_external_states
self._output_states = output_states self._output_states = output_states
self._average_pooling_type = average_pooling_type
if self._use_external_states and not self._causal: if self._use_external_states and not self._causal:
raise ValueError('External states should be used with causal mode.') raise ValueError('External states should be used with causal mode.')
...@@ -520,6 +524,7 @@ class Movinet(tf.keras.Model): ...@@ -520,6 +524,7 @@ class Movinet(tf.keras.Model):
batch_norm_layer=self._norm, batch_norm_layer=self._norm,
batch_norm_momentum=self._norm_momentum, batch_norm_momentum=self._norm_momentum,
batch_norm_epsilon=self._norm_epsilon, batch_norm_epsilon=self._norm_epsilon,
average_pooling_type=self._average_pooling_type,
state_prefix='state_head', state_prefix='state_head',
name='head') name='head')
x, states = layer_obj(x, states=states) x, states = layer_obj(x, states=states)
...@@ -730,4 +735,5 @@ def build_movinet( ...@@ -730,4 +735,5 @@ def build_movinet(
norm_epsilon=norm_activation_config.norm_epsilon, norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer, kernel_regularizer=l2_regularizer,
stochastic_depth_drop_rate=backbone_cfg.stochastic_depth_drop_rate, stochastic_depth_drop_rate=backbone_cfg.stochastic_depth_drop_rate,
use_external_states=backbone_cfg.use_external_states) use_external_states=backbone_cfg.use_external_states,
average_pooling_type=backbone_cfg.average_pooling_type)
...@@ -802,12 +802,14 @@ class StreamSqueezeExcitation(tf.keras.layers.Layer): ...@@ -802,12 +802,14 @@ class StreamSqueezeExcitation(tf.keras.layers.Layer):
states = dict(states) if states is not None else {} states = dict(states) if states is not None else {}
if self._se_type == '3d': if self._se_type == '3d':
x, states = self._spatiotemporal_pool(inputs, states=states) x, states = self._spatiotemporal_pool(
inputs, states=states, output_states=True)
elif self._se_type == '2d': elif self._se_type == '2d':
x = self._spatial_pool(inputs) x = self._spatial_pool(inputs)
elif self._se_type == '2plus3d': elif self._se_type == '2plus3d':
x_space = self._spatial_pool(inputs) x_space = self._spatial_pool(inputs)
x, states = self._spatiotemporal_pool(x_space, states=states) x, states = self._spatiotemporal_pool(
x_space, states=states, output_states=True)
if not self._causal: if not self._causal:
x = tf.tile(x, [1, tf.shape(inputs)[1], 1, 1, 1]) x = tf.tile(x, [1, tf.shape(inputs)[1], 1, 1, 1])
...@@ -1362,6 +1364,7 @@ class Head(tf.keras.layers.Layer): ...@@ -1362,6 +1364,7 @@ class Head(tf.keras.layers.Layer):
tf.keras.layers.BatchNormalization, tf.keras.layers.BatchNormalization,
batch_norm_momentum: float = 0.99, batch_norm_momentum: float = 0.99,
batch_norm_epsilon: float = 1e-3, batch_norm_epsilon: float = 1e-3,
average_pooling_type: str = '3d',
state_prefix: Optional[str] = None, # pytype: disable=annotation-type-mismatch # typed-keras state_prefix: Optional[str] = None, # pytype: disable=annotation-type-mismatch # typed-keras
**kwargs): **kwargs):
"""Implementation for video model head. """Implementation for video model head.
...@@ -1378,6 +1381,8 @@ class Head(tf.keras.layers.Layer): ...@@ -1378,6 +1381,8 @@ class Head(tf.keras.layers.Layer):
batch_norm_layer: class to use for batch norm. batch_norm_layer: class to use for batch norm.
batch_norm_momentum: momentum of the batch norm operation. batch_norm_momentum: momentum of the batch norm operation.
batch_norm_epsilon: epsilon of the batch norm operation. batch_norm_epsilon: epsilon of the batch norm operation.
average_pooling_type: The average pooling type. Currently supporting
['3d', '2d', 'none'].
state_prefix: a prefix string to identify states. state_prefix: a prefix string to identify states.
**kwargs: keyword arguments to be passed to this layer. **kwargs: keyword arguments to be passed to this layer.
""" """
...@@ -1404,8 +1409,16 @@ class Head(tf.keras.layers.Layer): ...@@ -1404,8 +1409,16 @@ class Head(tf.keras.layers.Layer):
batch_norm_momentum=self._batch_norm_momentum, batch_norm_momentum=self._batch_norm_momentum,
batch_norm_epsilon=self._batch_norm_epsilon, batch_norm_epsilon=self._batch_norm_epsilon,
name='project') name='project')
self._pool = nn_layers.GlobalAveragePool3D( if average_pooling_type.lower() == '3d':
keepdims=True, causal=False, state_prefix=state_prefix) self._pool = nn_layers.GlobalAveragePool3D(
keepdims=True, causal=False, state_prefix=state_prefix)
elif average_pooling_type.lower() == '2d':
self._pool = nn_layers.SpatialAveragePool3D(keepdims=True)
elif average_pooling_type == 'none':
self._pool = None
else:
raise ValueError(
'%s average_pooling_type is not supported.' % average_pooling_type)
def get_config(self): def get_config(self):
"""Returns a dictionary containing the config used for initialization.""" """Returns a dictionary containing the config used for initialization."""
...@@ -1439,7 +1452,11 @@ class Head(tf.keras.layers.Layer): ...@@ -1439,7 +1452,11 @@ class Head(tf.keras.layers.Layer):
""" """
states = dict(states) if states is not None else {} states = dict(states) if states is not None else {}
x = self._project(inputs) x = self._project(inputs)
return self._pool(x, states=states) if self._pool is not None:
outputs = self._pool(x, states=states, output_states=True)
else:
outputs = x
return outputs
@tf.keras.utils.register_keras_serializable(package='Vision') @tf.keras.utils.register_keras_serializable(package='Vision')
......
...@@ -697,7 +697,7 @@ class GlobalAveragePool3D(tf.keras.layers.Layer): ...@@ -697,7 +697,7 @@ class GlobalAveragePool3D(tf.keras.layers.Layer):
def call(self, def call(self,
inputs: tf.Tensor, inputs: tf.Tensor,
states: Optional[States] = None, states: Optional[States] = None,
output_states: bool = True output_states: bool = False
) -> Union[tf.Tensor, Tuple[tf.Tensor, States]]: ) -> Union[tf.Tensor, Tuple[tf.Tensor, States]]:
"""Calls the layer with the given inputs. """Calls the layer with the given inputs.
...@@ -813,13 +813,14 @@ class SpatialAveragePool3D(tf.keras.layers.Layer): ...@@ -813,13 +813,14 @@ class SpatialAveragePool3D(tf.keras.layers.Layer):
super(SpatialAveragePool3D, self).build(input_shape) super(SpatialAveragePool3D, self).build(input_shape)
def call(self, inputs): def call(self, inputs, states=None, output_states: bool = False):
"""Calls the layer with the given inputs.""" """Calls the layer with the given inputs."""
if inputs.shape.rank != 5: if inputs.shape.rank != 5:
raise ValueError( raise ValueError(
'Input should have rank {}, got {}'.format(5, inputs.shape.rank)) 'Input should have rank {}, got {}'.format(5, inputs.shape.rank))
return tf.reduce_mean(inputs, axis=(2, 3), keepdims=self._keepdims) output = tf.reduce_mean(inputs, axis=(2, 3), keepdims=self._keepdims)
return (output, states) if output_states else output
class CausalConvMixin: class CausalConvMixin:
......
...@@ -134,14 +134,14 @@ class NNLayersTest(parameterized.TestCase, tf.test.TestCase): ...@@ -134,14 +134,14 @@ class NNLayersTest(parameterized.TestCase, tf.test.TestCase):
inputs = tf.range(4, dtype=tf.float32) + 1. inputs = tf.range(4, dtype=tf.float32) + 1.
inputs = tf.reshape(inputs, [1, 4, 1, 1, 1]) inputs = tf.reshape(inputs, [1, 4, 1, 1, 1])
inputs = tf.tile(inputs, [1, 1, 2, 2, 3]) inputs = tf.tile(inputs, [1, 1, 2, 2, 3])
expected, _ = gap(inputs) expected, _ = gap(inputs, output_states=True)
for num_splits in [1, 2, 4]: for num_splits in [1, 2, 4]:
frames = tf.split(inputs, num_splits, axis=1) frames = tf.split(inputs, num_splits, axis=1)
states = {} states = {}
predicted = None predicted = None
for frame in frames: for frame in frames:
predicted, states = gap(frame, states=states) predicted, states = gap(frame, states=states, output_states=True)
self.assertEqual(predicted.shape, expected.shape) self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected) self.assertAllClose(predicted, expected)
...@@ -155,14 +155,14 @@ class NNLayersTest(parameterized.TestCase, tf.test.TestCase): ...@@ -155,14 +155,14 @@ class NNLayersTest(parameterized.TestCase, tf.test.TestCase):
inputs = tf.range(4, dtype=tf.float32) + 1. inputs = tf.range(4, dtype=tf.float32) + 1.
inputs = tf.reshape(inputs, [1, 4, 1, 1, 1]) inputs = tf.reshape(inputs, [1, 4, 1, 1, 1])
inputs = tf.tile(inputs, [1, 1, 2, 2, 3]) inputs = tf.tile(inputs, [1, 1, 2, 2, 3])
expected, _ = gap(inputs) expected, _ = gap(inputs, output_states=True)
for num_splits in [1, 2, 4]: for num_splits in [1, 2, 4]:
frames = tf.split(inputs, num_splits, axis=1) frames = tf.split(inputs, num_splits, axis=1)
states = {} states = {}
predicted = [] predicted = []
for frame in frames: for frame in frames:
x, states = gap(frame, states=states) x, states = gap(frame, states=states, output_states=True)
predicted.append(x) predicted.append(x)
predicted = tf.concat(predicted, axis=1) predicted = tf.concat(predicted, axis=1)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment