Commit 97b97572 authored by Liangzhe Yuan, committed by A. Unique TensorFlower

#movinet #vcc Add option average_pooling_type to MoViNet. Allow the model to output features with the time dimension preserved.

PiperOrigin-RevId: 445205488
parent d2f47b86
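As a rough illustration of what the new option does, here is a standalone sketch that mirrors the pooling choice added to the head in the diff below; it is not the library code itself, and the shapes assume a 5D [batch, frames, height, width, channels] feature map:

```python
import tensorflow as tf

def head_pool(features: tf.Tensor, average_pooling_type: str = '3d') -> tf.Tensor:
  """Illustrative stand-in for the pooling choice added to the MoViNet head."""
  if average_pooling_type == '3d':
    # Global spatiotemporal average: frames and spatial dims collapse to 1.
    return tf.reduce_mean(features, axis=(1, 2, 3), keepdims=True)
  elif average_pooling_type == '2d':
    # Spatial-only average: the time dimension is preserved.
    return tf.reduce_mean(features, axis=(2, 3), keepdims=True)
  elif average_pooling_type == 'none':
    # No pooling: the projected feature map passes through unchanged.
    return features
  raise ValueError('%s average_pooling_type is not supported.' % average_pooling_type)

x = tf.random.normal([2, 8, 7, 7, 480])  # [batch, frames, height, width, channels]
print(head_pool(x, '3d').shape)    # (2, 1, 1, 1, 480)
print(head_pool(x, '2d').shape)    # (2, 8, 1, 1, 480) -- time dimension kept
print(head_pool(x, 'none').shape)  # (2, 8, 7, 7, 480)
```

The '2d' path corresponds to the new SpatialAveragePool3D head pooling introduced in this commit; '3d' keeps the previous GlobalAveragePool3D behavior and remains the default.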
@@ -53,6 +53,7 @@ class Movinet(hyperparams.Config):
gating_activation: str = 'sigmoid'
stochastic_depth_drop_rate: float = 0.2
use_external_states: bool = False
+ average_pooling_type: str = '3d'
@dataclasses.dataclass
@@ -322,6 +322,7 @@ class Movinet(tf.keras.Model):
stochastic_depth_drop_rate: float = 0.,
use_external_states: bool = False,
output_states: bool = True,
+ average_pooling_type: str = '3d',
**kwargs):
"""MoViNet initialization function.
@@ -360,6 +361,8 @@
the model in streaming mode. Inputting the output states of the
previous input clip with the current input clip will utilize a stream
buffer for streaming video.
+ average_pooling_type: The average pooling type. Currently supporting
+ ['3d', '2d', 'none'].
**kwargs: keyword arguments to be passed.
"""
block_specs = BLOCK_SPECS[model_id]
@@ -393,6 +396,7 @@ class Movinet(tf.keras.Model):
self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
self._use_external_states = use_external_states
self._output_states = output_states
+ self._average_pooling_type = average_pooling_type
if self._use_external_states and not self._causal:
raise ValueError('External states should be used with causal mode.')
@@ -520,6 +524,7 @@ class Movinet(tf.keras.Model):
batch_norm_layer=self._norm,
batch_norm_momentum=self._norm_momentum,
batch_norm_epsilon=self._norm_epsilon,
+ average_pooling_type=self._average_pooling_type,
state_prefix='state_head',
name='head')
x, states = layer_obj(x, states=states)
@@ -730,4 +735,5 @@ def build_movinet(
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer,
stochastic_depth_drop_rate=backbone_cfg.stochastic_depth_drop_rate,
- use_external_states=backbone_cfg.use_external_states)
+ use_external_states=backbone_cfg.use_external_states,
+ average_pooling_type=backbone_cfg.average_pooling_type)
@@ -802,12 +802,14 @@ class StreamSqueezeExcitation(tf.keras.layers.Layer):
states = dict(states) if states is not None else {}
if self._se_type == '3d':
- x, states = self._spatiotemporal_pool(inputs, states=states)
+ x, states = self._spatiotemporal_pool(
+     inputs, states=states, output_states=True)
elif self._se_type == '2d':
x = self._spatial_pool(inputs)
elif self._se_type == '2plus3d':
x_space = self._spatial_pool(inputs)
- x, states = self._spatiotemporal_pool(x_space, states=states)
+ x, states = self._spatiotemporal_pool(
+     x_space, states=states, output_states=True)
if not self._causal:
x = tf.tile(x, [1, tf.shape(inputs)[1], 1, 1, 1])
@@ -1362,6 +1364,7 @@ class Head(tf.keras.layers.Layer):
tf.keras.layers.BatchNormalization,
batch_norm_momentum: float = 0.99,
batch_norm_epsilon: float = 1e-3,
+ average_pooling_type: str = '3d',
state_prefix: Optional[str] = None, # pytype: disable=annotation-type-mismatch # typed-keras
**kwargs):
"""Implementation for video model head.
@@ -1378,6 +1381,8 @@
batch_norm_layer: class to use for batch norm.
batch_norm_momentum: momentum of the batch norm operation.
batch_norm_epsilon: epsilon of the batch norm operation.
+ average_pooling_type: The average pooling type. Currently supporting
+ ['3d', '2d', 'none'].
state_prefix: a prefix string to identify states.
**kwargs: keyword arguments to be passed to this layer.
"""
@@ -1404,8 +1409,16 @@
batch_norm_momentum=self._batch_norm_momentum,
batch_norm_epsilon=self._batch_norm_epsilon,
name='project')
- self._pool = nn_layers.GlobalAveragePool3D(
-     keepdims=True, causal=False, state_prefix=state_prefix)
+ if average_pooling_type.lower() == '3d':
+   self._pool = nn_layers.GlobalAveragePool3D(
+       keepdims=True, causal=False, state_prefix=state_prefix)
+ elif average_pooling_type.lower() == '2d':
+   self._pool = nn_layers.SpatialAveragePool3D(keepdims=True)
+ elif average_pooling_type == 'none':
+   self._pool = None
+ else:
+   raise ValueError(
+       '%s average_pooling_type is not supported.' % average_pooling_type)
def get_config(self):
"""Returns a dictionary containing the config used for initialization."""
@@ -1439,7 +1452,11 @@
"""
states = dict(states) if states is not None else {}
x = self._project(inputs)
- return self._pool(x, states=states)
+ if self._pool is not None:
+   outputs = self._pool(x, states=states, output_states=True)
+ else:
+   outputs = x
+ return outputs
@tf.keras.utils.register_keras_serializable(package='Vision')
@@ -697,7 +697,7 @@ class GlobalAveragePool3D(tf.keras.layers.Layer):
def call(self,
inputs: tf.Tensor,
states: Optional[States] = None,
- output_states: bool = True
+ output_states: bool = False
) -> Union[tf.Tensor, Tuple[tf.Tensor, States]]:
"""Calls the layer with the given inputs.
@@ -813,13 +813,14 @@ class SpatialAveragePool3D(tf.keras.layers.Layer):
super(SpatialAveragePool3D, self).build(input_shape)
- def call(self, inputs):
+ def call(self, inputs, states=None, output_states: bool = False):
"""Calls the layer with the given inputs."""
if inputs.shape.rank != 5:
raise ValueError(
'Input should have rank {}, got {}'.format(5, inputs.shape.rank))
- return tf.reduce_mean(inputs, axis=(2, 3), keepdims=self._keepdims)
+ output = tf.reduce_mean(inputs, axis=(2, 3), keepdims=self._keepdims)
+ return (output, states) if output_states else output
class CausalConvMixin:
@@ -134,14 +134,14 @@ class NNLayersTest(parameterized.TestCase, tf.test.TestCase):
inputs = tf.range(4, dtype=tf.float32) + 1.
inputs = tf.reshape(inputs, [1, 4, 1, 1, 1])
inputs = tf.tile(inputs, [1, 1, 2, 2, 3])
- expected, _ = gap(inputs)
+ expected, _ = gap(inputs, output_states=True)
for num_splits in [1, 2, 4]:
frames = tf.split(inputs, num_splits, axis=1)
states = {}
predicted = None
for frame in frames:
- predicted, states = gap(frame, states=states)
+ predicted, states = gap(frame, states=states, output_states=True)
self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected)
@@ -155,14 +155,14 @@ class NNLayersTest(parameterized.TestCase, tf.test.TestCase):
inputs = tf.range(4, dtype=tf.float32) + 1.
inputs = tf.reshape(inputs, [1, 4, 1, 1, 1])
inputs = tf.tile(inputs, [1, 1, 2, 2, 3])
- expected, _ = gap(inputs)
+ expected, _ = gap(inputs, output_states=True)
for num_splits in [1, 2, 4]:
frames = tf.split(inputs, num_splits, axis=1)
states = {}
predicted = []
for frame in frames:
- x, states = gap(frame, states=states)
+ x, states = gap(frame, states=states, output_states=True)
predicted.append(x)
predicted = tf.concat(predicted, axis=1)
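One behavioral note on the nn_layers change above: GlobalAveragePool3D now defaults to output_states=False, and SpatialAveragePool3D accepts the same states/output_states arguments, so both return a plain tensor unless states are explicitly requested. A small usage sketch, assuming the nn_layers module shown in this diff (the exact import path may differ by release):

```python
import tensorflow as tf
# Assumed import path; the module may live elsewhere depending on the release.
from official.vision.modeling.layers import nn_layers

gap = nn_layers.GlobalAveragePool3D(keepdims=True, causal=False)
sap = nn_layers.SpatialAveragePool3D(keepdims=True)

x = tf.random.normal([1, 4, 7, 7, 8])  # [batch, frames, height, width, channels]

y = gap(x)                                         # tensor only: output_states now defaults to False
y, states = gap(x, states={}, output_states=True)  # (tensor, states) when explicitly requested

z = sap(x)                                         # spatial pool keeps the time dimension
z, states = sap(x, states=None, output_states=True)
```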