"test/vscode:/vscode.git/clone" did not exist on "9a8ee8a39a0aa6059c55faba05f6abb904fff6dd"
Commit a04d9e0e authored by Vishnu Banna's avatar Vishnu Banna
Browse files

merged

parents 64f16d61 bcbce005
@@ -27,7 +27,7 @@ class VideoClassificationModel(tf.keras.Model):
      self,
      backbone: tf.keras.Model,
      num_classes: int,
      input_specs: Optional[Mapping[str, tf.keras.layers.InputSpec]] = None,
      dropout_rate: float = 0.0,
      aggregate_endpoints: bool = False,
      kernel_initializer: str = 'random_uniform',
...
@@ -411,7 +411,7 @@ class _ApplyEdgeWeight(layers.Layer):
  def __init__(self,
               weights_shape,
               index: Optional[int] = None,
               use_5d_mode: bool = False,
               model_edge_weights: Optional[List[Any]] = None,
               **kwargs):
@@ -471,7 +471,7 @@ class _ApplyEdgeWeight(layers.Layer):
  def call(self,
           inputs: List[tf.Tensor],
           training: Optional[bool] = None) -> Mapping[Any, List[tf.Tensor]]:
    use_5d_mode = self._use_5d_mode
    dtype = inputs[0].dtype
    assert len(inputs) > 1
@@ -517,7 +517,7 @@ class _ApplyEdgeWeight(layers.Layer):
def multi_connection_fusion(inputs: List[tf.Tensor],
                            index: Optional[int] = None,
                            use_5d_mode: bool = False,
                            model_edge_weights: Optional[List[Any]] = None):
  """Do weighted summation of multiple different sized tensors.
@@ -893,7 +893,8 @@ class AssembleNetModel(tf.keras.Model):
               num_classes,
               num_frames: int,
               model_structure: List[Any],
               input_specs: Optional[Mapping[str,
                                             tf.keras.layers.InputSpec]] = None,
               max_pool_preditions: bool = False,
               **kwargs):
    if not input_specs:
@@ -1018,7 +1019,8 @@ def build_assemblenet_v1(
    input_specs: tf.keras.layers.InputSpec,
    backbone_config: hyperparams.Config,
    norm_activation_config: hyperparams.Config,
    l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None
) -> tf.keras.Model:
  """Builds assemblenet backbone."""
  del l2_regularizer
@@ -1058,7 +1060,7 @@ def build_assemblenet_model(
    input_specs: tf.keras.layers.InputSpec,
    model_config: cfg.AssembleNetModel,
    num_classes: int,
    l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None):
  """Builds assemblenet model."""
  input_specs_dict = {'image': input_specs}
  backbone = build_assemblenet_v1(input_specs, model_config.backbone,
...
@@ -8,16 +8,27 @@ This repository is the official implementation of

[MoViNets: Mobile Video Networks for Efficient Video
Recognition](https://arxiv.org/abs/2103.11511).
<p align="center">
<img src="https://storage.googleapis.com/tf_model_garden/vision/movinet/artifacts/hoverboard_stream.gif" height=500>
</p>
## Description

Mobile Video Networks (MoViNets) are efficient video classification models
runnable on mobile devices. MoViNets demonstrate state-of-the-art accuracy and
efficiency on several large-scale video action recognition datasets.
On [Kinetics 600](https://deepmind.com/research/open-source/kinetics),
MoViNet-A6 achieves 84.8% top-1 accuracy, outperforming recent
Vision Transformer models like [ViViT](https://arxiv.org/abs/2103.15691) (83.0%)
and [VATT](https://arxiv.org/abs/2104.11178) (83.6%) without any additional
training data, while using 10x fewer FLOPs. Streaming MoViNet-A0 achieves
72% accuracy while using 3x fewer FLOPs than MobileNetV3-large (68%).
There is a large performance gap between accurate and efficient models for
video action recognition. On the one hand, 2D MobileNet
CNNs are fast and can operate on streaming video in real time, but are prone to
be noisy and inaccurate. On the other hand, 3D CNNs are accurate, but are
memory and computation intensive and cannot operate on streaming video.

MoViNets bridge this gap, producing:
@@ -28,19 +39,22 @@ to A6).
  usage.
- Temporal ensembles of models to boost efficiency even higher.
MoViNets also improve computational efficiency by outputting high-quality
predictions frame by frame, as opposed to the traditional multi-clip evaluation
approach that performs redundant computation and limits temporal scope.

<p align="center">
  <img src="https://storage.googleapis.com/tf_model_garden/vision/movinet/artifacts/movinet_multi_clip_eval.png" height=200>
</p>

<p align="center">
  <img src="https://storage.googleapis.com/tf_model_garden/vision/movinet/artifacts/movinet_stream_eval.png" height=200>
</p>
## History

- **2021-05-30** Add streaming MoViNet checkpoints and examples.
- **2021-05-11** Initial Commit.
## Authors and Maintainers
@@ -53,6 +67,7 @@ single frame, as opposed to the traditional multi-clip evaluation approach.
- [Requirements](#requirements)
- [Results and Pretrained Weights](#results-and-pretrained-weights)
  - [Kinetics 600](#kinetics-600)
- [Prediction Examples](#prediction-examples)
- [Training and Evaluation](#training-and-evaluation)
- [References](#references)
- [License](#license)
@@ -76,33 +91,154 @@ pip install -r requirements.txt

### Kinetics 600
<p align="center">
  <img src="https://storage.googleapis.com/tf_model_garden/vision/movinet/artifacts/movinet_comparison.png" height=500>
</p>

[tensorboard.dev summary](https://tensorboard.dev/experiment/Q07RQUlVRWOY4yDw3SnSkA/)
of training runs across all models.
The table below summarizes the performance of each model on
[Kinetics 600](https://deepmind.com/research/open-source/kinetics)
and provides links to download pretrained models. All models are evaluated on
single clips with the same resolution as training.
Note: MoViNet-A6 can be constructed as an ensemble of MoViNet-A4 and
MoViNet-A5.
#### Base Models

Base models implement standard 3D convolutions without stream buffers.

| Model Name | Top-1 Accuracy | Top-5 Accuracy | Input Shape | GFLOPs\* | Checkpoint | TF Hub SavedModel |
|------------|----------------|----------------|-------------|----------|------------|-------------------|
| MoViNet-A0-Base | 72.28 | 90.92 | 50 x 172 x 172 | 2.7 | [checkpoint (12 MB)](https://storage.googleapis.com/tf_model_garden/vision/movinet/movinet_a0_base.tar.gz) | [tfhub](https://tfhub.dev/tensorflow/movinet/a0/base/kinetics-600/classification/) |
| MoViNet-A1-Base | 76.69 | 93.40 | 50 x 172 x 172 | 6.0 | [checkpoint (18 MB)](https://storage.googleapis.com/tf_model_garden/vision/movinet/movinet_a1_base.tar.gz) | [tfhub](https://tfhub.dev/tensorflow/movinet/a1/base/kinetics-600/classification/) |
| MoViNet-A2-Base | 78.62 | 94.17 | 50 x 224 x 224 | 10 | [checkpoint (20 MB)](https://storage.googleapis.com/tf_model_garden/vision/movinet/movinet_a2_base.tar.gz) | [tfhub](https://tfhub.dev/tensorflow/movinet/a2/base/kinetics-600/classification/) |
| MoViNet-A3-Base | 81.79 | 95.67 | 120 x 256 x 256 | 57 | [checkpoint (29 MB)](https://storage.googleapis.com/tf_model_garden/vision/movinet/movinet_a3_base.tar.gz) | [tfhub](https://tfhub.dev/tensorflow/movinet/a3/base/kinetics-600/classification/) |
| MoViNet-A4-Base | 83.48 | 96.16 | 80 x 290 x 290 | 110 | [checkpoint (44 MB)](https://storage.googleapis.com/tf_model_garden/vision/movinet/movinet_a4_base.tar.gz) | [tfhub](https://tfhub.dev/tensorflow/movinet/a4/base/kinetics-600/classification/) |
| MoViNet-A5-Base | 84.27 | 96.39 | 120 x 320 x 320 | 280 | [checkpoint (72 MB)](https://storage.googleapis.com/tf_model_garden/vision/movinet/movinet_a5_base.tar.gz) | [tfhub](https://tfhub.dev/tensorflow/movinet/a5/base/kinetics-600/classification/) |

\*GFLOPs per video on Kinetics 600.
#### Streaming Models

Streaming models implement causal 3D convolutions with stream buffers (see the
sketch after the table below).
| Model Name | Top-1 Accuracy | Top-5 Accuracy | Input Shape\* | GFLOPs\*\* | Checkpoint | TF Hub SavedModel |
|------------|----------------|----------------|---------------|------------|------------|-------------------|
| MoViNet-A0-Stream | 72.05 | 90.63 | 50 x 172 x 172 | 2.7 | [checkpoint (12 MB)](https://storage.googleapis.com/tf_model_garden/vision/movinet/movinet_a0_stream.tar.gz) | [tfhub](https://tfhub.dev/tensorflow/movinet/a0/stream/kinetics-600/classification/) |
| MoViNet-A1-Stream | 76.45 | 93.25 | 50 x 172 x 172 | 6.0 | [checkpoint (18 MB)](https://storage.googleapis.com/tf_model_garden/vision/movinet/movinet_a1_stream.tar.gz) | [tfhub](https://tfhub.dev/tensorflow/movinet/a1/stream/kinetics-600/classification/) |
| MoViNet-A2-Stream | 78.40 | 94.05 | 50 x 224 x 224 | 10 | [checkpoint (20 MB)](https://storage.googleapis.com/tf_model_garden/vision/movinet/movinet_a2_stream.tar.gz) | [tfhub](https://tfhub.dev/tensorflow/movinet/a2/stream/kinetics-600/classification/) |
| MoViNet-A3-Stream | 80.09 | 94.84 | 120 x 256 x 256 | 57 | [checkpoint (29 MB)](https://storage.googleapis.com/tf_model_garden/vision/movinet/movinet_a3_stream.tar.gz) | [tfhub](https://tfhub.dev/tensorflow/movinet/a3/stream/kinetics-600/classification/) |
| MoViNet-A4-Stream | 81.49 | 95.66 | 80 x 290 x 290 | 110 | [checkpoint (44 MB)](https://storage.googleapis.com/tf_model_garden/vision/movinet/movinet_a4_stream.tar.gz) | [tfhub](https://tfhub.dev/tensorflow/movinet/a4/stream/kinetics-600/classification/) |
| MoViNet-A5-Stream | 82.37 | 95.79 | 120 x 320 x 320 | 280 | [checkpoint (72 MB)](https://storage.googleapis.com/tf_model_garden/vision/movinet/movinet_a5_stream.tar.gz) | [tfhub](https://tfhub.dev/tensorflow/movinet/a5/stream/kinetics-600/classification/) |

\*In streaming mode, the number of frames corresponds to the total accumulated
duration of the 10-second clip.

\*\*GFLOPs per video on Kinetics 600.
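To make the stream-buffer idea concrete, here is a minimal sketch under
simplified 1-D assumptions (this is not the repo's `StreamBuffer` layer): a
causal temporal convolution caches the last `k - 1` frames and prepends them to
the next clip, so frame-by-frame results match running over the whole video at
once.

```python
import tensorflow as tf

def causal_conv_step(clip, buffer, conv):
  """Runs a temporal conv on one clip, carrying cached frames between calls."""
  # clip: [batch, frames, features]; buffer: [batch, k - 1, features].
  full = tf.concat([buffer, clip], axis=1)  # prepend the cached frames
  out = conv(full)  # 'valid' padding over the padded input emulates 'causal'
  new_buffer = full[:, -buffer.shape[1]:]  # keep the last k - 1 frames
  return out, new_buffer

k = 3
conv = tf.keras.layers.Conv1D(filters=8, kernel_size=k, padding='valid')
clip = tf.ones([1, 4, 16])
buffer = tf.zeros([1, k - 1, 16])
out, buffer = causal_conv_step(clip, buffer, conv)  # out: [1, 4, 8]
```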
## Prediction Examples
Please check out our [Colab Notebook](https://colab.research.google.com/github/tensorflow/models/tree/master/official/vision/beta/projects/movinet/movinet_tutorial.ipynb)
to get started with MoViNets.

This section provides examples of how to run prediction.
For base models, run the following:
```python
import tensorflow as tf
from official.vision.beta.projects.movinet.modeling import movinet
from official.vision.beta.projects.movinet.modeling import movinet_model
# Create backbone and model.
backbone = movinet.Movinet(model_id='a0')
model = movinet_model.MovinetClassifier(
    backbone, num_classes=600)
# Create your example input here.
# Refer to the paper for recommended input shapes.
inputs = tf.ones([1, 8, 172, 172, 3])
# [Optional] Build the model and load a pretrained checkpoint
model.build(inputs.shape)
checkpoint_dir = '/path/to/checkpoint'
checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
checkpoint = tf.train.Checkpoint(model=model)
status = checkpoint.restore(checkpoint_path)
status.assert_existing_objects_matched()
# Run the model prediction.
output = model(inputs)
prediction = tf.argmax(output, -1)
```
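To map the predicted index to a human-readable class, you can index into a
Kinetics-600 label list (a hedged sketch; the label-file path below is a
placeholder, not a file shipped with this repo):

```python
# Hypothetical label map: one Kinetics-600 class name per line, ordered to
# match the model's output logits.
with open('/path/to/kinetics600_labels.txt') as f:
  labels = [line.strip() for line in f]

print(labels[int(prediction[0])])
```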
For streaming models, run the following:
```python
import tensorflow as tf
from official.vision.beta.projects.movinet.modeling import movinet
from official.vision.beta.projects.movinet.modeling import movinet_model
# Create backbone and model.
backbone = movinet.Movinet(
    model_id='a0',
    causal=True,
    use_external_states=True,
)
model = movinet_model.MovinetClassifier(
    backbone, num_classes=600, output_states=True)
# Create your example input here.
# Refer to the paper for recommended input shapes.
inputs = tf.ones([1, 8, 172, 172, 3])
# [Optional] Build the model and load a pretrained checkpoint
model.build(inputs.shape)
checkpoint_dir = '/path/to/checkpoint'
checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
checkpoint = tf.train.Checkpoint(model=model)
status = checkpoint.restore(checkpoint_path)
status.assert_existing_objects_matched()
# Split the video into individual frames.
# Note: we could also split into larger clips (e.g., 8-frame clips).
# Running on larger clips will slightly reduce latency overhead, but
# will consume more memory.
frames = tf.split(inputs, inputs.shape[1], axis=1)
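# For example, to stream 2-frame clips instead of single frames (assuming the
# frame count divides evenly):
# clips = tf.split(inputs, inputs.shape[1] // 2, axis=1)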
# Initialize the dict of states. All state tensors are initially zeros.
init_states = model.init_states(tf.shape(inputs))
# Run the model prediction by looping over each frame.
states = init_states
predictions = []
for frame in frames:
  output, states = model({**states, 'image': frame})
  predictions.append(output)
# The video classification will simply be the last output of the model.
final_prediction = tf.argmax(predictions[-1], -1)
# Alternatively, we can run the network on the entire input video.
# The output should be effectively the same
# (but it may differ a small amount due to floating point errors).
non_streaming_output, _ = model({**init_states, 'image': inputs})
non_streaming_prediction = tf.argmax(non_streaming_output, -1)
```
## Training and Evaluation
Run this command line for continuous training and evaluation.

```shell
@@ -137,11 +273,6 @@ python3 official/vision/beta/projects/movinet/train.py \
    --tf_data_service=""
```
## License

[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)

...
@@ -45,6 +45,7 @@ class Movinet(hyperparams.Config):
  # 3d_2plus1d: (2+1)D convolution with Conv3D (no 2D reshaping)
  conv_type: str = '3d'
  stochastic_depth_drop_rate: float = 0.2
  use_external_states: bool = False


@dataclasses.dataclass
...
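As a usage sketch of the new config field (hedged; the constructor arguments
shown are the ones exercised elsewhere in this commit), `use_external_states`
pairs with `causal` mode on the backbone:

```python
from official.vision.beta.projects.movinet.modeling import movinet

# Build a causal backbone that expects its stream-buffer states to be passed
# in explicitly; use_external_states requires causal=True.
backbone = movinet.Movinet(
    model_id='a0',
    causal=True,
    conv_type='3d',
    use_external_states=True,
)
```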
# Video classification on Kinetics-600 using MoViNet-A5-Stream backbone.
# --experiment_type=movinet_kinetics600
# Achieves 82.37% Top-1 accuracy.
# http://mldash/experiments/7675567202035803461
runtime:
  distribution_strategy: 'tpu'
  mixed_precision_dtype: 'bfloat16'
task:
  losses:
    l2_weight_decay: 0.00003
    label_smoothing: 0.1
  model:
    backbone:
      movinet:
        model_id: 'a5'
        causal: true
        use_positional_encoding: true
        stochastic_depth_drop_rate: 0.2
    norm_activation:
      use_sync_bn: true
    dropout_rate: 0.5
  train_data:
    name: kinetics600
    variant_name: rgb
    feature_shape: !!python/tuple
    - 32
    - 320
    - 320
    - 3
    temporal_stride: 2
    random_stride_range: 1
    global_batch_size: 1024
    dtype: 'bfloat16'
    shuffle_buffer_size: 1024
    min_image_size: 368
    aug_max_area_ratio: 1.0
    aug_max_aspect_ratio: 2.0
    aug_min_area_ratio: 0.08
    aug_min_aspect_ratio: 0.5
    aug_type: 'autoaug'
  validation_data:
    name: kinetics600
    feature_shape: !!python/tuple
    - 120
    - 320
    - 320
    - 3
    temporal_stride: 2
    num_test_clips: 1
    num_test_crops: 1
    global_batch_size: 32
    min_image_size: 368
    dtype: 'bfloat16'
    drop_remainder: false
trainer:
  optimizer_config:
    learning_rate:
      cosine:
        initial_learning_rate: 1.8
        decay_steps: 85785
    warmup:
      linear:
        warmup_steps: 2145
    optimizer:
      type: 'rmsprop'
      rmsprop:
        rho: 0.9
        momentum: 0.9
        epsilon: 1.0
        clipnorm: 1.0
  train_steps: 85785
  steps_per_loop: 500
  summary_interval: 500
  validation_interval: 500
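A hedged sketch of how a config like this is typically consumed in the Model
Garden flow (the module paths and `override_params_dict` call are assumptions
based on other Model Garden projects, not part of this diff):

```python
from official.core import exp_factory
from official.modeling import hyperparams

# Start from the registered experiment, then override with the YAML above
# (the YAML path is a placeholder).
config = exp_factory.get_exp_config('movinet_kinetics600')
config = hyperparams.override_params_dict(
    config, '/path/to/movinet_a5_stream_k600.yaml', is_strict=True)
```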
@@ -19,38 +19,18 @@ Export example:

```shell
python3 export_saved_model.py \
  --export_path=/tmp/movinet/ \
  --model_id=a0 \
  --causal=True \
  --conv_type="3d" \
  --num_classes=600 \
  --use_positional_encoding=False \
  --checkpoint_path=""
```
To use an exported saved_model, refer to export_saved_model_test.py.
"""

from absl import app
from absl import flags

import tensorflow as tf
@@ -59,8 +39,8 @@ from official.vision.beta.projects.movinet.modeling import movinet
from official.vision.beta.projects.movinet.modeling import movinet_model

flags.DEFINE_string(
    'export_path', '/tmp/movinet/',
    'Export path to save the saved_model file.')
flags.DEFINE_string(
    'model_id', 'a0', 'MoViNet model name.')
flags.DEFINE_bool(
@@ -73,8 +53,20 @@ flags.DEFINE_string(
    '3x3 followed by 5x1 conv). 3d_2plus1d uses (2+1)D convolution with '
    'Conv3D and no 2D reshaping (e.g., a 5x3x3 kernel becomes 1x3x3 '
    'followed by 5x1x1 conv).')
flags.DEFINE_bool(
    'use_positional_encoding', False,
    'Whether to use positional encoding (only applied when causal=True).')
flags.DEFINE_integer(
    'num_classes', 600, 'The number of classes for prediction.')
flags.DEFINE_integer(
    'batch_size', None,
    'The batch size of the input. Set to None for dynamic input.')
flags.DEFINE_integer(
    'num_frames', None,
    'The number of frames of the input. Set to None for dynamic input.')
flags.DEFINE_integer(
    'image_size', None,
    'The resolution of the input. Set to None for dynamic input.')
flags.DEFINE_string(
    'checkpoint_path', '',
    'Checkpoint path to load. Leave blank for default initialization.')
@@ -82,75 +74,79 @@ flags.DEFINE_string(

FLAGS = flags.FLAGS
def main(_) -> None:
  input_specs = tf.keras.layers.InputSpec(shape=[
      FLAGS.batch_size,
      FLAGS.num_frames,
      FLAGS.image_size,
      FLAGS.image_size,
      3,
  ])

  # Use dimensions of 1 except the channels to export faster,
  # since we only really need the last dimension to build and get the output
  # states. These dimensions will be set to `None` once the model is built.
  input_shape = [1 if s is None else s for s in input_specs.shape]

  backbone = movinet.Movinet(
      FLAGS.model_id,
      causal=FLAGS.causal,
      conv_type=FLAGS.conv_type,
      use_external_states=FLAGS.causal,
      input_specs=input_specs,
      use_positional_encoding=FLAGS.use_positional_encoding)
  model = movinet_model.MovinetClassifier(
      backbone,
      num_classes=FLAGS.num_classes,
      output_states=FLAGS.causal,
      input_specs=dict(image=input_specs))
  model.build(input_shape)

  # Compile model to generate some internal Keras variables.
  model.compile()

  if FLAGS.checkpoint_path:
    checkpoint = tf.train.Checkpoint(model=model)
    status = checkpoint.restore(FLAGS.checkpoint_path)
    status.assert_existing_objects_matched()

  if FLAGS.causal:
    # Call the model once to get the output states. Call again with `states`
    # input to ensure that the inputs with the `states` argument are built
    # with the full output state shapes.
    input_image = tf.ones(input_shape)
    _, states = model({**model.init_states(input_shape), 'image': input_image})
    _, states = model({**states, 'image': input_image})

    # Create a function to explicitly set the names of the outputs.
    def predict(inputs):
      outputs, states = model(inputs)
      return {**states, 'logits': outputs}

    specs = {
        name: tf.TensorSpec(spec.shape, name=name, dtype=spec.dtype)
        for name, spec in model.initial_state_specs(input_specs.shape).items()
    }
    specs['image'] = tf.TensorSpec(
        input_specs.shape, dtype=model.dtype, name='image')

    predict_fn = tf.function(predict, jit_compile=True)
    predict_fn = predict_fn.get_concrete_function(specs)

    init_states_fn = tf.function(model.init_states, jit_compile=True)
    init_states_fn = init_states_fn.get_concrete_function(
        tf.TensorSpec([5], dtype=tf.int32))

    signatures = {'call': predict_fn, 'init_states': init_states_fn}

    tf.keras.models.save_model(
        model, FLAGS.export_path, signatures=signatures)
  else:
    _ = model(tf.ones(input_shape))
    tf.keras.models.save_model(model, FLAGS.export_path)

  print(' ----- Done. Saved Model is saved at {}'.format(FLAGS.export_path))


if __name__ == '__main__':
...
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for export_saved_model."""
from absl import flags
import tensorflow as tf
import tensorflow_hub as hub
from official.vision.beta.projects.movinet import export_saved_model
FLAGS = flags.FLAGS
class ExportSavedModelTest(tf.test.TestCase):

  def test_movinet_export_a0_base_with_tfhub(self):
    saved_model_path = self.get_temp_dir()

    FLAGS.export_path = saved_model_path
    FLAGS.model_id = 'a0'
    FLAGS.causal = False
    FLAGS.num_classes = 600

    export_saved_model.main('unused_args')

    encoder = hub.KerasLayer(saved_model_path, trainable=True)

    inputs = tf.keras.layers.Input(
        shape=[None, None, None, 3],
        dtype=tf.float32)
    outputs = encoder(dict(image=inputs))

    model = tf.keras.Model(inputs, outputs)

    example_input = tf.ones([1, 8, 172, 172, 3])
    outputs = model(example_input)

    self.assertEqual(outputs.shape, [1, 600])

  def test_movinet_export_a0_stream_with_tfhub(self):
    saved_model_path = self.get_temp_dir()

    FLAGS.export_path = saved_model_path
    FLAGS.model_id = 'a0'
    FLAGS.causal = True
    FLAGS.num_classes = 600

    export_saved_model.main('unused_args')

    encoder = hub.KerasLayer(saved_model_path, trainable=True)

    image_input = tf.keras.layers.Input(
        shape=[None, None, None, 3],
        dtype=tf.float32,
        name='image')

    init_states_fn = encoder.resolved_object.signatures['init_states']
    state_shapes = {
        name: ([s if s > 0 else None for s in state.shape], state.dtype)
        for name, state in init_states_fn(tf.constant([0, 0, 0, 0, 3])).items()
    }
    states_input = {
        name: tf.keras.Input(shape[1:], dtype=dtype, name=name)
        for name, (shape, dtype) in state_shapes.items()
    }

    inputs = {**states_input, 'image': image_input}

    outputs = encoder(inputs)

    model = tf.keras.Model(inputs, outputs)

    example_input = tf.ones([1, 8, 172, 172, 3])
    frames = tf.split(example_input, example_input.shape[1], axis=1)

    init_states = init_states_fn(tf.shape(example_input))

    expected_outputs, _ = model({**init_states, 'image': example_input})

    states = init_states
    for frame in frames:
      outputs, states = model({**states, 'image': frame})

    self.assertEqual(outputs.shape, [1, 600])
    self.assertNotEmpty(states)
    self.assertAllClose(outputs, expected_outputs, 1e-5, 1e-5)


if __name__ == '__main__':
  tf.test.main()
@@ -17,7 +17,8 @@

Reference: https://arxiv.org/pdf/2103.11511.pdf
"""

import math
from typing import Dict, Mapping, Optional, Sequence, Tuple, Union

import dataclasses
import tensorflow as tf
@@ -71,8 +72,6 @@ class HeadSpec(BlockSpec):
  """Configuration of a Movinet block."""
  project_filters: int = 0
  head_filters: int = 0


# Block specs specify the architecture of each model
@@ -317,6 +316,7 @@ class Movinet(tf.keras.Model):
               kernel_regularizer: Optional[str] = None,
               bias_regularizer: Optional[str] = None,
               stochastic_depth_drop_rate: float = 0.,
               use_external_states: bool = False,
               **kwargs):
    """MoViNet initialization function.
@@ -344,6 +344,8 @@ class Movinet(tf.keras.Model):
      bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2d.
        Defaults to None.
      stochastic_depth_drop_rate: the base rate for stochastic depth.
      use_external_states: if True, expects states to be passed as additional
        input.
      **kwargs: keyword arguments to be passed.
    """
    block_specs = BLOCK_SPECS[model_id]
@@ -371,7 +373,10 @@ class Movinet(tf.keras.Model):
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer
    self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
    self._use_external_states = use_external_states

    if self._use_external_states and not self._causal:
      raise ValueError('External states should be used with causal mode.')

    if not isinstance(block_specs[0], StemSpec):
      raise ValueError(
          'Expected first spec to be StemSpec, got {}'.format(block_specs[0]))
@@ -380,22 +385,55 @@ class Movinet(tf.keras.Model):
          'Expected final spec to be HeadSpec, got {}'.format(block_specs[-1]))
    self._head_filters = block_specs[-1].head_filters
    state_specs = None
    if use_external_states:
      self._set_dtype_policy(input_specs.dtype)
      state_specs = self.initial_state_specs(input_specs.shape)

    inputs, outputs = self._build_network(input_specs, state_specs=state_specs)

    super(Movinet, self).__init__(inputs=inputs, outputs=outputs, **kwargs)

    self._state_specs = state_specs

  def _build_network(
      self,
      input_specs: tf.keras.layers.InputSpec,
      state_specs: Optional[Mapping[str, tf.keras.layers.InputSpec]] = None,
  ) -> Tuple[Mapping[str, tf.keras.Input], Tuple[Mapping[str, tf.Tensor],
                                                 Mapping[str, tf.Tensor]]]:
    """Builds the model network.

    Args:
      input_specs: the model input spec to use.
      state_specs: a dict mapping a state name to the corresponding state spec.
        State names should match with the `state` input/output dict.

    Returns:
      Inputs and outputs as a tuple. Inputs are expected to be a dict with
      base input and states. Outputs are expected to be a dict of endpoints
      and output states.
    """
    state_specs = state_specs if state_specs is not None else {}

    image_input = tf.keras.Input(shape=input_specs.shape[1:], name='inputs')
    states = {
        name: tf.keras.Input(shape=spec.shape[1:], dtype=spec.dtype, name=name)
        for name, spec in state_specs.items()
    }
    inputs = {**states, 'image': image_input}
    endpoints = {}
    x = image_input

    num_layers = sum(
        len(block.expand_filters)
        for block in self._block_specs
        if isinstance(block, MovinetBlockSpec))
    stochastic_depth_idx = 1
    for block_idx, block in enumerate(self._block_specs):
      if isinstance(block, StemSpec):
        x, states = movinet_layers.Stem(
            block.filters,
@@ -404,12 +442,14 @@ class Movinet(tf.keras.Model):
            conv_type=self._conv_type,
            causal=self._causal,
            activation=self._activation,
            kernel_initializer=self._kernel_initializer,
            kernel_regularizer=self._kernel_regularizer,
            batch_norm_layer=self._norm,
            batch_norm_momentum=self._norm_momentum,
            batch_norm_epsilon=self._norm_epsilon,
            state_prefix='state/stem',
            name='stem')(
                x, states=states)
        endpoints['stem'] = x
      elif isinstance(block, MovinetBlockSpec):
        if not (len(block.expand_filters) == len(block.kernel_sizes) ==
@@ -437,14 +477,16 @@ class Movinet(tf.keras.Model):
            activation=self._activation,
            stochastic_depth_drop_rate=stochastic_depth_drop_rate,
            conv_type=self._conv_type,
            use_positional_encoding=self._use_positional_encoding and
            self._causal,
            kernel_initializer=self._kernel_initializer,
            kernel_regularizer=self._kernel_regularizer,
            batch_norm_layer=self._norm,
            batch_norm_momentum=self._norm_momentum,
            batch_norm_epsilon=self._norm_epsilon,
            state_prefix=f'state/{name}',
            name=name)(
                x, states=states)
        endpoints[name] = x
        stochastic_depth_idx += 1
      elif isinstance(block, HeadSpec):
@@ -452,27 +494,163 @@ class Movinet(tf.keras.Model):
            project_filters=block.project_filters,
            conv_type=self._conv_type,
            activation=self._activation,
            kernel_initializer=self._kernel_initializer,
            kernel_regularizer=self._kernel_regularizer,
            batch_norm_layer=self._norm,
            batch_norm_momentum=self._norm_momentum,
            batch_norm_epsilon=self._norm_epsilon,
            state_prefix='state/head',
            name='head')(
                x, states=states)
        endpoints['head'] = x
      else:
        raise ValueError('Unknown block type {}'.format(block))

    outputs = (endpoints, states)

    return inputs, outputs
  def _get_initial_state_shapes(
      self,
      block_specs: Sequence[BlockSpec],
      input_shape: Union[Sequence[int], tf.Tensor],
      use_positional_encoding: bool = False) -> Dict[str, Sequence[int]]:
    """Generates names and shapes for all input states.

    Args:
      block_specs: sequence of specs used for creating a model.
      input_shape: the expected 5D shape of the image input.
      use_positional_encoding: whether the model will use positional encoding.

    Returns:
      A dict mapping state names to state shapes.
    """
    def divide_resolution(shape, num_downsamples):
      """Downsamples the dimension to calculate strided convolution shape."""
      if shape is None:
        return None
      if isinstance(shape, tf.Tensor):
        # Avoid using div and ceil to support tf lite
        shape = tf.cast(shape, tf.float32)
        resolution_divisor = 2 ** num_downsamples
        resolution_multiplier = 0.5 ** num_downsamples
        shape = ((shape + resolution_divisor - 1) * resolution_multiplier)
        return tf.cast(shape, tf.int32)
      else:
        resolution_divisor = 2 ** num_downsamples
        return math.ceil(shape / resolution_divisor)
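      # Example: shape=17 with num_downsamples=3 gives resolution_divisor=8
      # and ceil(17 / 8) = 3; the tf lite-friendly branch above computes
      # (17 + 8 - 1) * 0.125 = 3.0, which casts to the same int32 value.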
    states = {}
    num_downsamples = 0

    for block_idx, block in enumerate(block_specs):
      if isinstance(block, StemSpec):
        if block.kernel_size[0] > 1:
          states['state/stem/stream_buffer'] = (
              input_shape[0],
              input_shape[1],
              divide_resolution(input_shape[2], num_downsamples),
              divide_resolution(input_shape[3], num_downsamples),
              block.filters,
          )
        num_downsamples += 1
      elif isinstance(block, MovinetBlockSpec):
        block_idx -= 1
        params = list(zip(
            block.expand_filters,
            block.kernel_sizes,
            block.strides))
        for layer_idx, layer in enumerate(params):
          expand_filters, kernel_size, strides = layer

          # If we use a 2D kernel, we apply spatial downsampling
          # before the buffer.
          if (tuple(strides[1:3]) != (1, 1) and
              self._conv_type in ['2plus1d', '3d_2plus1d']):
            num_downsamples += 1

          if kernel_size[0] > 1:
            states[f'state/b{block_idx}/l{layer_idx}/stream_buffer'] = (
                input_shape[0],
                kernel_size[0] - 1,
                divide_resolution(input_shape[2], num_downsamples),
                divide_resolution(input_shape[3], num_downsamples),
                expand_filters,
            )

          states[f'state/b{block_idx}/l{layer_idx}/pool_buffer'] = (
              input_shape[0], 1, 1, 1, expand_filters,
          )
          states[f'state/b{block_idx}/l{layer_idx}/pool_frame_count'] = (1,)

          if use_positional_encoding:
            name = f'state/b{block_idx}/l{layer_idx}/pos_enc_frame_count'
            states[name] = (1,)

          if strides[1] != strides[2]:
            raise ValueError('Strides must match in the spatial dimensions, '
                             'got {}'.format(strides))

          # If we use a 3D kernel, we apply spatial downsampling
          # after the buffer.
          if (tuple(strides[1:3]) != (1, 1) and
              self._conv_type not in ['2plus1d', '3d_2plus1d']):
            num_downsamples += 1
      elif isinstance(block, HeadSpec):
        states['state/head/pool_buffer'] = (
            input_shape[0], 1, 1, 1, block.project_filters,
        )
        states['state/head/pool_frame_count'] = (1,)

    return states
  def _get_state_dtype(self, name: str) -> str:
    """Returns the dtype associated with a state."""
    if 'frame_count' in name:
      return 'int32'
    return self.dtype

  def initial_state_specs(
      self, input_shape: Sequence[int]) -> Dict[str, tf.keras.layers.InputSpec]:
    """Creates a mapping of state name to InputSpec from the input shape."""
    state_shapes = self._get_initial_state_shapes(
        self._block_specs,
        input_shape,
        use_positional_encoding=self._use_positional_encoding)

    return {
        name: tf.keras.layers.InputSpec(
            shape=shape, dtype=self._get_state_dtype(name))
        for name, shape in state_shapes.items()
    }
  def init_states(self, input_shape: Sequence[int]) -> Dict[str, tf.Tensor]:
    """Returns initial states for the first call in streaming mode."""
    state_shapes = self._get_initial_state_shapes(
        self._block_specs,
        input_shape,
        use_positional_encoding=self._use_positional_encoding)

    states = {
        name: tf.zeros(shape, dtype=self._get_state_dtype(name))
        for name, shape in state_shapes.items()
    }
    return states

  @property
  def use_external_states(self) -> bool:
    """Whether this model is expecting input states as additional input."""
    return self._use_external_states

  @property
  def head_filters(self):
    """The number of filters expected to be in the head classifier layer."""
    return self._head_filters

  @property
  def conv_type(self):
    """The expected convolution type (see __init__ for more details)."""
    return self._conv_type
  def get_config(self):
    config_dict = {
@@ -495,11 +673,6 @@ class Movinet(tf.keras.Model):
  def from_config(cls, config, custom_objects=None):
    return cls(**config)
@factory.register_backbone_builder('movinet')
def build_movinet(
@@ -508,8 +681,6 @@ def build_movinet(
    norm_activation_config: hyperparams.Config,
    l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
  """Builds MoViNet backbone from a config."""
  backbone_type = backbone_config.type
  backbone_cfg = backbone_config.get()
  assert backbone_type == 'movinet', ('Inconsistent backbone type '
@@ -526,4 +697,5 @@ def build_movinet(
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer,
      stochastic_depth_drop_rate=backbone_cfg.stochastic_depth_drop_rate,
      use_external_states=backbone_cfg.use_external_states)
@@ -18,7 +18,7 @@

Reference: https://arxiv.org/pdf/2103.11511.pdf
"""

from typing import Any, Mapping, Optional, Sequence, Tuple, Union

import tensorflow as tf
@@ -270,7 +270,6 @@ class ConvBlock(tf.keras.layers.Layer):
               batch_norm_epsilon: float = 1e-3,
               activation: Optional[Any] = None,
               conv_type: str = '3d',
               use_buffered_input: bool = False,
               **kwargs):
    """Initializes a conv block.
@@ -293,9 +292,6 @@ class ConvBlock(tf.keras.layers.Layer):
        ops. '2plus1d' split any 3D ops into two sequential 2D ops with their
        own batch norm and activation. '3d_2plus1d' is like '2plus1d', but
        uses two sequential 3D ops instead.
      use_buffered_input: if True, the input is expected to be padded
        beforehand. In effect, calling this layer will use 'valid' padding on
        the temporal dimension to simulate 'causal' padding.
@@ -324,7 +320,6 @@ class ConvBlock(tf.keras.layers.Layer):
    self._batch_norm_epsilon = batch_norm_epsilon
    self._activation = activation
    self._conv_type = conv_type
    self._use_buffered_input = use_buffered_input

    if activation is not None:
@@ -350,7 +345,6 @@ class ConvBlock(tf.keras.layers.Layer):
        'batch_norm_epsilon': self._batch_norm_epsilon,
        'activation': self._activation,
        'conv_type': self._conv_type,
        'use_buffered_input': self._use_buffered_input,
    }
    base_config = super(ConvBlock, self).get_config()
@@ -426,11 +420,6 @@ class ConvBlock(tf.keras.layers.Layer):
          use_buffered_input=self._use_buffered_input,
          name='conv3d')

    self._batch_norm = None
    self._batch_norm_temporal = None
@@ -451,9 +440,6 @@ class ConvBlock(tf.keras.layers.Layer):
    """Calls the layer with the given inputs."""
    x = inputs

    x = self._conv(x)
    if self._batch_norm is not None:
      x = self._batch_norm(x)
@@ -461,9 +447,6 @@ class ConvBlock(tf.keras.layers.Layer):
      x = self._activation_layer(x)

    if self._conv_temporal is not None:
      x = self._conv_temporal(x)
      if self._batch_norm_temporal is not None:
        x = self._batch_norm_temporal(x)
@@ -477,11 +460,15 @@ class ConvBlock(tf.keras.layers.Layer):
class StreamBuffer(tf.keras.layers.Layer):
  """Stream buffer wrapper which caches activations of previous frames."""

  def __init__(self,
               buffer_size: int,
               state_prefix: Optional[str] = None,
               **kwargs):
    """Initializes a stream buffer.

    Args:
      buffer_size: the number of input frames to cache.
      state_prefix: a prefix string to identify states.
      **kwargs: keyword arguments to be passed to this layer.

    Returns:
@@ -489,36 +476,32 @@ class StreamBuffer(tf.keras.layers.Layer):
    """
    super(StreamBuffer, self).__init__(**kwargs)
    state_prefix = state_prefix if state_prefix is not None else ''
    self._state_prefix = state_prefix
    self._state_name = f'{state_prefix}/stream_buffer'
    self._buffer_size = buffer_size
  def get_config(self):
    """Returns a dictionary containing the config used for initialization."""
    config = {
        'buffer_size': self._buffer_size,
        'state_prefix': self._state_prefix,
    }
    base_config = super(StreamBuffer, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
  def call(
      self,
      inputs: tf.Tensor,
      states: Optional[nn_layers.States] = None,
  ) -> Tuple[Any, nn_layers.States]:
    """Calls the layer with the given inputs.

    Args:
      inputs: the input tensor.
      states: a dict of states; if any key matches this layer's state name,
        its value overwrites the contents of the buffer(s).
        Expected keys include `state_prefix + '/stream_buffer'`.

    Returns:
      the output tensor and states
@@ -526,12 +509,16 @@ class StreamBuffer(tf.keras.layers.Layer):
    """
    states = dict(states) if states is not None else {}
    buffer = states.get(self._state_name, None)
    # Create the buffer if it does not exist in the states.
    # Output buffer shape:
    # [batch_size, buffer_size, input_height, input_width, num_channels]
    if buffer is None:
      shape = tf.shape(inputs)
      buffer = tf.zeros(
          [shape[0], self._buffer_size, shape[2], shape[3], shape[4]],
          dtype=inputs.dtype)

    # tf.pad has limited support for tf lite, so use tf.concat instead.
    full_inputs = tf.concat([buffer, inputs], axis=1)
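    # For example, with buffer_size=2 and a clip of shape [b, 3, h, w, c],
    # `full_inputs` has shape [b, 5, h, w, c]; the caching below keeps its
    # last 2 frames as the buffer for the next call.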
    # Cache the last b frames of the input where b is the buffer size and f
@@ -557,16 +544,16 @@ class StreamConvBlock(ConvBlock):
               causal: bool = False,
               use_bias: bool = False,
               kernel_initializer: tf.keras.initializers.Initializer = 'HeNormal',
               kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = tf.keras
               .regularizers.L2(KERNEL_WEIGHT_DECAY),
               use_batch_norm: bool = True,
               batch_norm_layer: tf.keras.layers.Layer = tf.keras.layers.experimental
               .SyncBatchNormalization,
               batch_norm_momentum: float = 0.99,
               batch_norm_epsilon: float = 1e-3,
               activation: Optional[Any] = None,
               conv_type: str = '3d',
               state_prefix: Optional[str] = None,
               **kwargs):
    """Initializes a stream conv block.
@@ -588,7 +575,7 @@ class StreamConvBlock(ConvBlock):
        ops. '2plus1d' split any 3D ops into two sequential 2D ops with their
        own batch norm and activation. '3d_2plus1d' is like '2plus1d', but
        uses two sequential 3D ops instead.
      state_prefix: a prefix string to identify states.
      **kwargs: keyword arguments to be passed to this layer.

    Returns:
...@@ -598,6 +585,8 @@ class StreamConvBlock(ConvBlock): ...@@ -598,6 +585,8 @@ class StreamConvBlock(ConvBlock):
buffer_size = kernel_size[0] - 1 buffer_size = kernel_size[0] - 1
use_buffer = buffer_size > 0 and causal use_buffer = buffer_size > 0 and causal
self._state_prefix = state_prefix
super(StreamConvBlock, self).__init__( super(StreamConvBlock, self).__init__(
filters, filters,
kernel_size, kernel_size,
...@@ -613,18 +602,17 @@ class StreamConvBlock(ConvBlock): ...@@ -613,18 +602,17 @@ class StreamConvBlock(ConvBlock):
batch_norm_epsilon=batch_norm_epsilon, batch_norm_epsilon=batch_norm_epsilon,
activation=activation, activation=activation,
conv_type=conv_type, conv_type=conv_type,
use_positional_encoding=use_positional_encoding,
use_buffered_input=use_buffer, use_buffered_input=use_buffer,
**kwargs) **kwargs)
self._stream_buffer = None self._stream_buffer = None
if use_buffer: if use_buffer:
self._stream_buffer = StreamBuffer( self._stream_buffer = StreamBuffer(
buffer_size=buffer_size) buffer_size=buffer_size, state_prefix=state_prefix)
def get_config(self): def get_config(self):
"""Returns a dictionary containing the config used for initialization.""" """Returns a dictionary containing the config used for initialization."""
config = {} config = {'state_prefix': self._state_prefix}
base_config = super(StreamConvBlock, self).get_config() base_config = super(StreamConvBlock, self).get_config()
return dict(list(base_config.items()) + list(config.items())) return dict(list(base_config.items()) + list(config.items()))
...@@ -645,9 +633,28 @@ class StreamConvBlock(ConvBlock): ...@@ -645,9 +633,28 @@ class StreamConvBlock(ConvBlock):
states = dict(states) if states is not None else {} states = dict(states) if states is not None else {}
x = inputs x = inputs
if self._stream_buffer is not None:
# If we have no separate temporal conv, use the buffer before the 3D conv.
if self._conv_temporal is None and self._stream_buffer is not None:
x, states = self._stream_buffer(x, states=states) x, states = self._stream_buffer(x, states=states)
x = super(StreamConvBlock, self).call(x)
x = self._conv(x)
if self._batch_norm is not None:
x = self._batch_norm(x)
if self._activation_layer is not None:
x = self._activation_layer(x)
if self._conv_temporal is not None:
if self._stream_buffer is not None:
# If we have a separate temporal conv, use the buffer before the
# 1D conv instead (otherwise, we may waste computation on the 2D conv).
x, states = self._stream_buffer(x, states=states)
x = self._conv_temporal(x)
if self._batch_norm_temporal is not None:
x = self._batch_norm_temporal(x)
if self._activation_layer is not None:
x = self._activation_layer(x)
return x, states return x, states
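Because `call` threads the `states` dict through the stream buffer, a block can consume a video one frame at a time. A rough usage sketch, assuming this module's `StreamConvBlock` is in scope (shapes and the prefix value are hypothetical):

import tensorflow as tf

block = StreamConvBlock(
    8, (3, 3, 3), causal=True, state_prefix='state/stem')  # filters, kernel
states = {}
for _ in range(4):  # one frame per step
  frame = tf.ones([1, 1, 16, 16, 3])
  x, states = block(frame, states=states)  # states carry the frame buffer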
...@@ -667,9 +674,10 @@ class StreamSqueezeExcitation(tf.keras.layers.Layer):
      causal: bool = False,
      conv_type: str = '3d',
      kernel_initializer: tf.keras.initializers.Initializer = 'HeNormal',
      kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = tf.keras
      .regularizers.L2(KERNEL_WEIGHT_DECAY),
      use_positional_encoding: bool = False,
      state_prefix: Optional[str] = None,
      **kwargs):
    """Implementation for squeeze and excitation.
...@@ -686,6 +694,7 @@ class StreamSqueezeExcitation(tf.keras.layers.Layer):
      kernel_regularizer: kernel regularizer for the conv operation.
      use_positional_encoding: add a positional encoding after the (cumulative)
        global average pooling layer.
      state_prefix: a prefix string to identify states.
      **kwargs: keyword arguments to be passed to this layer.
    """
    super(StreamSqueezeExcitation, self).__init__(**kwargs)
...@@ -698,13 +707,15 @@ class StreamSqueezeExcitation(tf.keras.layers.Layer):
    self._kernel_initializer = kernel_initializer
    self._kernel_regularizer = kernel_regularizer
    self._use_positional_encoding = use_positional_encoding
    self._state_prefix = state_prefix

    self._pool = nn_layers.GlobalAveragePool3D(
        keepdims=True, causal=causal, state_prefix=state_prefix)

    self._pos_encoding = None
    if use_positional_encoding:
      self._pos_encoding = nn_layers.PositionalEncoding(
          initializer='zeros', state_prefix=state_prefix)

  def get_config(self):
    """Returns a dictionary containing the config used for initialization."""
...@@ -717,6 +728,7 @@ class StreamSqueezeExcitation(tf.keras.layers.Layer):
        'kernel_initializer': self._kernel_initializer,
        'kernel_regularizer': self._kernel_regularizer,
        'use_positional_encoding': self._use_positional_encoding,
        'state_prefix': self._state_prefix,
    }
    base_config = super(StreamSqueezeExcitation, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
...@@ -768,7 +780,7 @@ class StreamSqueezeExcitation(tf.keras.layers.Layer):
    x, states = self._pool(inputs, states=states)

    if self._pos_encoding is not None:
      x, states = self._pos_encoding(x, states=states)

    x = self._se_reduce(x)
    x = self._se_expand(x)
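The squeeze step relies on a causal (cumulative) global average pool, so the gating at frame t only depends on frames up to t. A rough illustration of that computation, assuming frames sit on axis 1 (this sketches the idea only, not the actual `nn_layers.GlobalAveragePool3D` implementation):

import tensorflow as tf

def causal_global_average_pool(x):
  # x: [batch, frames, height, width, channels].
  spatial_mean = tf.reduce_mean(x, axis=[2, 3], keepdims=True)
  cumulative_sum = tf.cumsum(spatial_mean, axis=1)
  # Average over all frames seen so far at each time step.
  counts = tf.cast(tf.range(1, tf.shape(x)[1] + 1), x.dtype)
  return cumulative_sum / counts[tf.newaxis, :, tf.newaxis, tf.newaxis,
                                 tf.newaxis]

pooled = causal_global_average_pool(tf.ones([1, 4, 8, 8, 3]))  # [1, 4, 1, 1, 3]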
...@@ -992,12 +1004,13 @@ class MovinetBlock(tf.keras.layers.Layer):
      conv_type: str = '3d',
      use_positional_encoding: bool = False,
      kernel_initializer: tf.keras.initializers.Initializer = 'HeNormal',
      kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = tf.keras
      .regularizers.L2(KERNEL_WEIGHT_DECAY),
      batch_norm_layer: tf.keras.layers.Layer = tf.keras.layers.experimental
      .SyncBatchNormalization,
      batch_norm_momentum: float = 0.99,
      batch_norm_epsilon: float = 1e-3,
      state_prefix: Optional[str] = None,
      **kwargs):
    """Implementation for MoViNet block.
...@@ -1021,6 +1034,7 @@ class MovinetBlock(tf.keras.layers.Layer):
      batch_norm_layer: class to use for batch norm.
      batch_norm_momentum: momentum of the batch norm operation.
      batch_norm_epsilon: epsilon of the batch norm operation.
      state_prefix: a prefix string to identify states.
      **kwargs: keyword arguments to be passed to this layer.
    """
    super(MovinetBlock, self).__init__(**kwargs)
...@@ -1045,6 +1059,7 @@ class MovinetBlock(tf.keras.layers.Layer):
    self._batch_norm_layer = batch_norm_layer
    self._batch_norm_momentum = batch_norm_momentum
    self._batch_norm_epsilon = batch_norm_epsilon
    self._state_prefix = state_prefix

    self._expansion = ConvBlock(
        expand_filters,
...@@ -1066,15 +1081,14 @@ class MovinetBlock(tf.keras.layers.Layer):
        causal=self._causal,
        activation=activation,
        conv_type=conv_type,
        kernel_initializer=kernel_initializer,
        kernel_regularizer=kernel_regularizer,
        use_batch_norm=True,
        batch_norm_layer=self._batch_norm_layer,
        batch_norm_momentum=self._batch_norm_momentum,
        batch_norm_epsilon=self._batch_norm_epsilon,
        state_prefix=state_prefix,
        name='feature')

    self._projection = ConvBlock(
        out_filters,
        (1, 1, 1),
...@@ -1095,6 +1109,7 @@ class MovinetBlock(tf.keras.layers.Layer):
        use_positional_encoding=use_positional_encoding,
        kernel_initializer=kernel_initializer,
        kernel_regularizer=kernel_regularizer,
        state_prefix=state_prefix,
        name='se')

  def get_config(self):
...@@ -1114,6 +1129,7 @@ class MovinetBlock(tf.keras.layers.Layer):
        'kernel_regularizer': self._kernel_regularizer,
        'batch_norm_momentum': self._batch_norm_momentum,
        'batch_norm_epsilon': self._batch_norm_epsilon,
        'state_prefix': self._state_prefix,
    }
    base_config = super(MovinetBlock, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
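The `state_prefix` plumbed through these constructors decides which keys a block reads and writes in the shared `states` dict; per the `StreamBuffer` docstring earlier in this diff, buffer state lives under `state_prefix + '/stream_buffer'`. A tiny illustration (the prefix value is hypothetical):

import tensorflow as tf

state_prefix = 'state/b0/l0'
buffer_key = state_prefix + '/stream_buffer'  # read/written by StreamBuffer
states = {buffer_key: tf.zeros([1, 2, 32, 32, 8])}  # overwrites that buffer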
...@@ -1176,12 +1192,13 @@ class Stem(tf.keras.layers.Layer):
      conv_type: str = '3d',
      activation: nn_layers.Activation = 'swish',
      kernel_initializer: tf.keras.initializers.Initializer = 'HeNormal',
      kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = tf.keras
      .regularizers.L2(KERNEL_WEIGHT_DECAY),
      batch_norm_layer: tf.keras.layers.Layer = tf.keras.layers.experimental
      .SyncBatchNormalization,
      batch_norm_momentum: float = 0.99,
      batch_norm_epsilon: float = 1e-3,
      state_prefix: Optional[str] = None,
      **kwargs):
    """Implementation for video model stem.
...@@ -1200,35 +1217,38 @@ class Stem(tf.keras.layers.Layer):
      batch_norm_layer: class to use for batch norm.
      batch_norm_momentum: momentum of the batch norm operation.
      batch_norm_epsilon: epsilon of the batch norm operation.
      state_prefix: a prefix string to identify states.
      **kwargs: keyword arguments to be passed to this layer.
    """
    super(Stem, self).__init__(**kwargs)
    self._out_filters = out_filters
    self._kernel_size = normalize_tuple(kernel_size, 3, 'kernel_size')
    self._strides = normalize_tuple(strides, 3, 'strides')
    self._causal = causal
    self._conv_type = conv_type
    self._activation = activation
    self._kernel_initializer = kernel_initializer
    self._kernel_regularizer = kernel_regularizer
    self._batch_norm_layer = batch_norm_layer
    self._batch_norm_momentum = batch_norm_momentum
    self._batch_norm_epsilon = batch_norm_epsilon
    self._state_prefix = state_prefix

    self._stem = StreamConvBlock(
        filters=self._out_filters,
        kernel_size=self._kernel_size,
        strides=self._strides,
        causal=self._causal,
        activation=self._activation,
        conv_type=self._conv_type,
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        use_batch_norm=True,
        batch_norm_layer=self._batch_norm_layer,
        batch_norm_momentum=self._batch_norm_momentum,
        batch_norm_epsilon=self._batch_norm_epsilon,
        state_prefix=self._state_prefix,
        name='stem')

  def get_config(self):
...@@ -1238,11 +1258,13 @@ class Stem(tf.keras.layers.Layer):
        'kernel_size': self._kernel_size,
        'strides': self._strides,
        'causal': self._causal,
        'activation': self._activation,
        'conv_type': self._conv_type,
        'kernel_initializer': self._kernel_initializer,
        'kernel_regularizer': self._kernel_regularizer,
        'batch_norm_momentum': self._batch_norm_momentum,
        'batch_norm_epsilon': self._batch_norm_epsilon,
        'state_prefix': self._state_prefix,
    }
    base_config = super(Stem, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
...@@ -1278,12 +1300,13 @@ class Head(tf.keras.layers.Layer):
      conv_type: str = '3d',
      activation: nn_layers.Activation = 'swish',
      kernel_initializer: tf.keras.initializers.Initializer = 'HeNormal',
      kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = tf.keras
      .regularizers.L2(KERNEL_WEIGHT_DECAY),
      batch_norm_layer: tf.keras.layers.Layer = tf.keras.layers.experimental
      .SyncBatchNormalization,
      batch_norm_momentum: float = 0.99,
      batch_norm_epsilon: float = 1e-3,
      state_prefix: Optional[str] = None,
      **kwargs):
    """Implementation for video model head.
...@@ -1299,17 +1322,20 @@ class Head(tf.keras.layers.Layer):
      batch_norm_layer: class to use for batch norm.
      batch_norm_momentum: momentum of the batch norm operation.
      batch_norm_epsilon: epsilon of the batch norm operation.
      state_prefix: a prefix string to identify states.
      **kwargs: keyword arguments to be passed to this layer.
    """
    super(Head, self).__init__(**kwargs)
    self._project_filters = project_filters
    self._conv_type = conv_type
    self._activation = activation
    self._kernel_initializer = kernel_initializer
    self._kernel_regularizer = kernel_regularizer
    self._batch_norm_layer = batch_norm_layer
    self._batch_norm_momentum = batch_norm_momentum
    self._batch_norm_epsilon = batch_norm_epsilon
    self._state_prefix = state_prefix

    self._project = ConvBlock(
        filters=project_filters,
...@@ -1322,25 +1348,29 @@ class Head(tf.keras.layers.Layer):
        batch_norm_momentum=self._batch_norm_momentum,
        batch_norm_epsilon=self._batch_norm_epsilon,
        name='project')
    self._pool = nn_layers.GlobalAveragePool3D(
        keepdims=True, causal=False, state_prefix=state_prefix)

  def get_config(self):
    """Returns a dictionary containing the config used for initialization."""
    config = {
        'project_filters': self._project_filters,
        'conv_type': self._conv_type,
        'activation': self._activation,
        'kernel_initializer': self._kernel_initializer,
        'kernel_regularizer': self._kernel_regularizer,
        'batch_norm_momentum': self._batch_norm_momentum,
        'batch_norm_epsilon': self._batch_norm_epsilon,
        'state_prefix': self._state_prefix,
    }
    base_config = super(Head, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(
      self,
      inputs: Union[tf.Tensor, Mapping[str, tf.Tensor]],
      states: Optional[nn_layers.States] = None,
  ) -> Tuple[tf.Tensor, nn_layers.States]:
    """Calls the layer with the given inputs.

    Args:
...
...@@ -146,7 +146,6 @@ class MovinetLayersTest(parameterized.TestCase, tf.test.TestCase):
        use_bias=False,
        activation='relu',
        conv_type='2plus1d',
    )
    stream_conv_block = movinet_layers.StreamConvBlock(
...@@ -158,7 +157,6 @@ class MovinetLayersTest(parameterized.TestCase, tf.test.TestCase):
        use_bias=False,
        activation='relu',
        conv_type='2plus1d',
    )

    inputs = tf.ones([1, 4, 2, 2, 3])
...@@ -197,7 +195,6 @@ class MovinetLayersTest(parameterized.TestCase, tf.test.TestCase):
        use_bias=False,
        activation='relu',
        conv_type='3d_2plus1d',
    )
    stream_conv_block = movinet_layers.StreamConvBlock(
...@@ -209,7 +206,6 @@ class MovinetLayersTest(parameterized.TestCase, tf.test.TestCase):
        use_bias=False,
        activation='relu',
        conv_type='3d_2plus1d',
    )

    inputs = tf.ones([1, 4, 2, 2, 3])
...
...@@ -16,7 +16,7 @@
Reference: https://arxiv.org/pdf/2103.11511.pdf
"""

from typing import Any, Dict, Mapping, Optional, Sequence, Tuple, Union

from absl import logging
import tensorflow as tf
...@@ -31,16 +31,17 @@ from official.vision.beta.projects.movinet.modeling import movinet_layers
class MovinetClassifier(tf.keras.Model):
  """A video classification class builder."""

  def __init__(
      self,
      backbone: tf.keras.Model,
      num_classes: int,
      input_specs: Optional[Mapping[str, tf.keras.layers.InputSpec]] = None,
      dropout_rate: float = 0.0,
      kernel_initializer: str = 'HeNormal',
      kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
      bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
      output_states: bool = False,
      **kwargs):
    """Movinet initialization function.

    Args:
...@@ -70,47 +71,110 @@ class MovinetClassifier(tf.keras.Model):
    self._bias_regularizer = bias_regularizer
    self._output_states = output_states

    state_specs = None
    if backbone.use_external_states:
      state_specs = backbone.initial_state_specs(
          input_shape=input_specs['image'].shape)

    inputs, outputs = self._build_network(
        backbone, input_specs, state_specs=state_specs)

    super(MovinetClassifier, self).__init__(
        inputs=inputs, outputs=outputs, **kwargs)

    # Move backbone after super() call so Keras is happy
    self._backbone = backbone

  def _build_network(
      self,
      backbone: tf.keras.Model,
      input_specs: Mapping[str, tf.keras.layers.InputSpec],
      state_specs: Optional[Mapping[str, tf.keras.layers.InputSpec]] = None,
  ) -> Tuple[Mapping[str, tf.keras.Input], Union[Tuple[Mapping[
      str, tf.Tensor], Mapping[str, tf.Tensor]], Mapping[str, tf.Tensor]]]:
    """Builds the model network.

    Args:
      backbone: the model backbone.
      input_specs: the model input spec to use.
      state_specs: a dict of states such that, if any of the keys match for a
        layer, will overwrite the contents of the buffer(s).

    Returns:
      Inputs and outputs as a tuple. Inputs are expected to be a dict with
      base input and states. Outputs are expected to be a dict of endpoints
      and (optionally) output states.
    """
    state_specs = state_specs if state_specs is not None else {}

    states = {
        name: tf.keras.Input(shape=spec.shape[1:], dtype=spec.dtype, name=name)
        for name, spec in state_specs.items()
    }
    image = tf.keras.Input(shape=input_specs['image'].shape[1:], name='image')
    inputs = {**states, 'image': image}

    if backbone.use_external_states:
      before_states = states
      endpoints, states = backbone(inputs)
      after_states = states

      new_states = set(after_states) - set(before_states)
      if new_states:
        raise ValueError(
            'Expected input and output states to be the same. Got extra states '
            '{}, expected {}'.format(new_states, set(before_states)))

      mismatched_shapes = {}
      for name in after_states:
        before_shape = before_states[name].shape
        after_shape = after_states[name].shape
        if len(before_shape) != len(after_shape):
          mismatched_shapes[name] = (before_shape, after_shape)
          continue
        for before, after in zip(before_shape, after_shape):
          if before is not None and after is not None and before != after:
            mismatched_shapes[name] = (before_shape, after_shape)
            break
      if mismatched_shapes:
        raise ValueError(
            'Got mismatched input and output state shapes: {}'.format(
                mismatched_shapes))
    else:
      endpoints, states = backbone(inputs)

    x = endpoints['head']

    x = movinet_layers.ClassifierHead(
        head_filters=backbone.head_filters,
        num_classes=self._num_classes,
        dropout_rate=self._dropout_rate,
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        conv_type=backbone.conv_type)(
            x)

    outputs = (x, states) if self._output_states else x

    return inputs, outputs

  def initial_state_specs(
      self, input_shape: Sequence[int]) -> Dict[str, tf.keras.layers.InputSpec]:
    return self._backbone.initial_state_specs(input_shape=input_shape)

  @tf.function
  def init_states(self, input_shape: Sequence[int]) -> Dict[str, tf.Tensor]:
    """Returns initial states for the first call in streaming mode."""
    return self._backbone.init_states(input_shape)

  @property
  def checkpoint_items(self) -> Dict[str, Any]:
    """Returns a dictionary of items to be additionally checkpointed."""
    return dict(backbone=self.backbone)

  @property
  def backbone(self) -> tf.keras.Model:
    """Returns the backbone of the model."""
    return self._backbone

  def get_config(self):
...@@ -141,10 +205,10 @@ class MovinetClassifier(tf.keras.Model):
@model_factory.register_model_builder('movinet')
def build_movinet_model(
    input_specs: Mapping[str, tf.keras.layers.InputSpec],
    model_config: cfg.MovinetModel,
    num_classes: int,
    l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None):
  """Builds movinet model."""
  logging.info('Building movinet model with num classes: %s', num_classes)
  if l2_regularizer is not None:
...
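Putting the constructor and `_build_network` together, a minimal usage sketch of the classifier (import paths as used elsewhere in this diff; shapes are hypothetical):

import tensorflow as tf
from official.vision.beta.projects.movinet.modeling import movinet
from official.vision.beta.projects.movinet.modeling import movinet_model

# Non-streaming: a single dict input holding the whole clip.
backbone = movinet.Movinet(model_id='a0', causal=True)
model = movinet_model.MovinetClassifier(backbone, num_classes=600)
logits = model({'image': tf.ones([1, 8, 172, 172, 3])})

# Streaming: external states are initialized once, then threaded through.
stream_backbone = movinet.Movinet(
    model_id='a0', causal=True, use_external_states=True)
stream_model = movinet_model.MovinetClassifier(
    stream_backbone, num_classes=600, output_states=True)
states = stream_model.init_states([1, 8, 172, 172, 3])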
...@@ -48,28 +48,85 @@ class MovinetModelTest(parameterized.TestCase, tf.test.TestCase):
    self.assertAllEqual([2, num_classes], logits.shape)

  def test_movinet_classifier_stream(self):
    """Test if the classifier can be run in streaming mode."""
    tf.keras.backend.set_image_data_format('channels_last')

    backbone = movinet.Movinet(
        model_id='a0',
        causal=True,
        use_external_states=True,
    )
    model = movinet_model.MovinetClassifier(
        backbone, num_classes=600, output_states=True)

    inputs = tf.ones([1, 8, 172, 172, 3])

    init_states = model.init_states(tf.shape(inputs))
    expected, _ = model({**init_states, 'image': inputs})

    frames = tf.split(inputs, inputs.shape[1], axis=1)

    states = init_states
    for frame in frames:
      output, states = model({**states, 'image': frame})
    predicted = output

    self.assertEqual(predicted.shape, expected.shape)
    self.assertAllClose(predicted, expected, 1e-5, 1e-5)

  def test_movinet_classifier_stream_pos_enc(self):
    """Test if the classifier can be run in streaming mode with pos encoding."""
    tf.keras.backend.set_image_data_format('channels_last')

    backbone = movinet.Movinet(
        model_id='a0',
        causal=True,
        use_external_states=True,
        use_positional_encoding=True,
    )
    model = movinet_model.MovinetClassifier(
        backbone, num_classes=600, output_states=True)

    inputs = tf.ones([1, 8, 172, 172, 3])

    init_states = model.init_states(tf.shape(inputs))
    expected, _ = model({**init_states, 'image': inputs})

    frames = tf.split(inputs, inputs.shape[1], axis=1)

    states = init_states
    for frame in frames:
      output, states = model({**states, 'image': frame})
    predicted = output

    self.assertEqual(predicted.shape, expected.shape)
    self.assertAllClose(predicted, expected, 1e-5, 1e-5)

  def test_movinet_classifier_stream_pos_enc_2plus1d(self):
    """Test if the model can run in streaming mode with pos encoding, (2+1)D."""
    tf.keras.backend.set_image_data_format('channels_last')

    backbone = movinet.Movinet(
        model_id='a0',
        causal=True,
        use_external_states=True,
        use_positional_encoding=True,
        conv_type='2plus1d',
    )
    model = movinet_model.MovinetClassifier(
        backbone, num_classes=600, output_states=True)

    inputs = tf.ones([1, 8, 172, 172, 3])

    init_states = model.init_states(tf.shape(inputs))
    expected, _ = model({**init_states, 'image': inputs})

    frames = tf.split(inputs, inputs.shape[1], axis=1)

    states = init_states
    for frame in frames:
      output, states = model({**states, 'image': frame})
    predicted = output

    self.assertEqual(predicted.shape, expected.shape)
    self.assertAllClose(predicted, expected, 1e-5, 1e-5)
...
...@@ -48,14 +48,15 @@ class MoViNetTest(parameterized.TestCase, tf.test.TestCase):
    """Test creation of MoViNet family models with states."""
    tf.keras.backend.set_image_data_format('channels_last')

    backbone = movinet.Movinet(
        model_id='a0',
        causal=True,
        use_external_states=True,
    )
    inputs = tf.ones([1, 8, 128, 128, 3])

    init_states = backbone.init_states(tf.shape(inputs))
    endpoints, new_states = backbone({**init_states, 'image': inputs})

    self.assertAllEqual(endpoints['stem'].shape, [1, 8, 64, 64, 8])
    self.assertAllEqual(endpoints['b0/l0'].shape, [1, 8, 32, 32, 8])
...@@ -65,25 +66,28 @@ class MoViNetTest(parameterized.TestCase, tf.test.TestCase):
    self.assertAllEqual(endpoints['b4/l0'].shape, [1, 8, 4, 4, 104])
    self.assertAllEqual(endpoints['head'].shape, [1, 1, 1, 1, 480])

    self.assertNotEmpty(init_states)
    self.assertNotEmpty(new_states)

  def test_movinet_stream(self):
    """Test if the backbone can be run in streaming mode."""
    tf.keras.backend.set_image_data_format('channels_last')

    backbone = movinet.Movinet(
        model_id='a0',
        causal=True,
        use_external_states=True,
    )
    inputs = tf.ones([1, 5, 128, 128, 3])

    init_states = backbone.init_states(tf.shape(inputs))
    expected_endpoints, _ = backbone({**init_states, 'image': inputs})

    frames = tf.split(inputs, inputs.shape[1], axis=1)

    states = init_states
    for frame in frames:
      output, states = backbone({**states, 'image': frame})
    predicted_endpoints = output

    predicted = predicted_endpoints['head']
...@@ -98,20 +102,22 @@ class MoViNetTest(parameterized.TestCase, tf.test.TestCase):
  def test_movinet_2plus1d_stream(self):
    tf.keras.backend.set_image_data_format('channels_last')

    backbone = movinet.Movinet(
        model_id='a0',
        causal=True,
        conv_type='2plus1d',
        use_external_states=True,
    )
    inputs = tf.ones([1, 5, 128, 128, 3])

    init_states = backbone.init_states(tf.shape(inputs))
    expected_endpoints, _ = backbone({**init_states, 'image': inputs})

    frames = tf.split(inputs, inputs.shape[1], axis=1)

    states = init_states
    for frame in frames:
      output, states = backbone({**states, 'image': frame})
    predicted_endpoints = output

    predicted = predicted_endpoints['head']
...@@ -126,20 +132,22 @@ class MoViNetTest(parameterized.TestCase, tf.test.TestCase):
  def test_movinet_3d_2plus1d_stream(self):
    tf.keras.backend.set_image_data_format('channels_last')

    backbone = movinet.Movinet(
        model_id='a0',
        causal=True,
        conv_type='3d_2plus1d',
        use_external_states=True,
    )
    inputs = tf.ones([1, 5, 128, 128, 3])

    init_states = backbone.init_states(tf.shape(inputs))
    expected_endpoints, _ = backbone({**init_states, 'image': inputs})

    frames = tf.split(inputs, inputs.shape[1], axis=1)

    states = init_states
    for frame in frames:
      output, states = backbone({**states, 'image': frame})
    predicted_endpoints = output

    predicted = predicted_endpoints['head']
...@@ -157,6 +165,7 @@ class MoViNetTest(parameterized.TestCase, tf.test.TestCase):
        model_id='a0',
        causal=True,
        use_positional_encoding=True,
        use_external_states=True,
    )
    network = movinet.Movinet(**kwargs)
...
...@@ -72,7 +72,7 @@ trainer:
      type: 'cosine'
      cosine:
        initial_learning_rate: 0.6 # 0.3 × BatchSize / 256
        decay_steps: 48000
    warmup:
      type: 'linear'
      linear:
...
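The learning-rate comment above follows the usual linear scaling rule. Spelled out, assuming the global batch size of 512 that the configured value implies (the batch size is not stated in this hunk):

base_lr = 0.3
global_batch_size = 512
initial_learning_rate = base_lr * global_batch_size / 256  # = 0.6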
# SimCLR Imagenet 10% finetuning.
runtime:
  distribution_strategy: 'mirrored'
  mixed_precision_dtype: 'float16'
...@@ -55,7 +55,7 @@ trainer:
  train_steps: 12500 # 100 epochs
  validation_steps: 49 # NUM_EXAMPLES (50000) // global_batch_size
  validation_interval: 125
  steps_per_loop: 125 # NUM_EXAMPLES (128116) // global_batch_size
  summary_interval: 125
  checkpoint_interval: 125
  optimizer_config:
...
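The interval arithmetic in this finetuning config works out as follows, assuming the 10% ImageNet subset size given in the comment:

num_examples = 128116        # imagenet2012_subset/10pct train split
global_batch_size = 1024
steps_per_loop = num_examples // global_batch_size  # = 125, about one epoch
train_steps = 100 * steps_per_loop                  # = 12500, i.e. 100 epochs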
# SimCLR Imagenet 10% finetuning.
runtime:
  distribution_strategy: 'tpu'
  mixed_precision_dtype: 'bfloat16'
task:
  model:
    mode: 'finetune'
    input_size: [224, 224, 3]
    backbone:
      type: 'resnet'
      resnet:
        model_id: 50
    backbone_trainable: true
    projection_head:
      proj_output_dim: 128
      num_proj_layers: 3
      ft_proj_idx: 1
    supervised_head:
      num_classes: 1001
      zero_init: true
    norm_activation:
      use_sync_bn: false
      norm_momentum: 0.9
      norm_epsilon: 0.00001
  loss:
    label_smoothing: 0.0
    one_hot: true
  evaluation:
    top_k: 5
    one_hot: true
  init_checkpoint: gs://tf_model_garden/vision/simclr/r50_1x
  init_checkpoint_modules: 'backbone_projection'
  train_data:
    tfds_name: 'imagenet2012_subset/10pct'
    tfds_split: 'train'
    input_path: ''
    is_training: true
    global_batch_size: 1024
    dtype: 'bfloat16'
    parser:
      mode: 'finetune'
  validation_data:
    tfds_name: 'imagenet2012_subset/10pct'
    tfds_split: 'validation'
    input_path: ''
    is_training: false
    global_batch_size: 1024
    dtype: 'bfloat16'
    drop_remainder: false
    parser:
      mode: 'finetune'
trainer:
  train_steps: 12500 # 100 epochs
  validation_steps: 49 # NUM_EXAMPLES (50000) // global_batch_size
  validation_interval: 125
  steps_per_loop: 125 # NUM_EXAMPLES (128116) // global_batch_size
  summary_interval: 125
  checkpoint_interval: 125
  optimizer_config:
    optimizer:
      type: 'lars'
      lars:
        momentum: 0.9
        weight_decay_rate: 0.0
        exclude_from_weight_decay: ['batch_normalization', 'bias']
    learning_rate:
      type: 'cosine'
      cosine:
        initial_learning_rate: 0.04 # 0.01 × BatchSize / 256
        decay_steps: 12500 # train_steps
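A quick check of the schedule above (values taken from this config; 0.04 is consistent with scaling a base rate of 0.01 by BatchSize / 256 at batch size 1024):

global_batch_size = 1024
initial_learning_rate = 0.01 * global_batch_size / 256  # = 0.04 as configured
decay_steps = 12500  # equal to train_steps, so the cosine spans the whole run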
# SimCLR Imagenet pretraining.
runtime:
  distribution_strategy: 'mirrored'
  mixed_precision_dtype: 'float16'
...@@ -49,12 +49,12 @@ task:
    decoder:
      decode_label: true
trainer:
  train_steps: 500000 # 800 epochs
  validation_steps: 24 # NUM_EXAMPLES (50000) // global_batch_size
  validation_interval: 625
  steps_per_loop: 625 # NUM_EXAMPLES (1281167) // global_batch_size
  summary_interval: 625
  checkpoint_interval: 625
  optimizer_config:
    optimizer:
      type: 'lars'
...@@ -66,8 +66,8 @@ trainer:
      type: 'cosine'
      cosine:
        initial_learning_rate: 1.6 # 0.2 * BatchSize / 256
        decay_steps: 500000
    warmup:
      type: 'linear'
      linear:
        warmup_steps: 25000 # 5% of total epochs
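The updated pretraining schedule is internally consistent; spelled out with the batch size of 2048 used by this config:

global_batch_size = 2048
steps_per_epoch = 1281167 // global_batch_size          # = 625
train_steps = 800 * steps_per_epoch                     # = 500000 (800 epochs)
initial_learning_rate = 0.2 * global_batch_size / 256   # = 1.6
warmup_steps = train_steps // 20                        # = 25000, 5% of total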
# SimCLR Imagenet pretraining.
runtime:
  distribution_strategy: 'tpu'
  mixed_precision_dtype: 'bfloat16'
task:
  model:
    mode: 'pretrain'
    input_size: [224, 224, 3]
    backbone:
      type: 'resnet'
      resnet:
        model_id: 50
    backbone_trainable: true
    projection_head:
      proj_output_dim: 128
      num_proj_layers: 3
      ft_proj_idx: 0
    supervised_head:
      num_classes: 1001
    norm_activation:
      use_sync_bn: true
      norm_momentum: 0.9
      norm_epsilon: 0.00001
  loss:
    projection_norm: true
    temperature: 0.1
  evaluation:
    top_k: 5
    one_hot: true
  train_data:
    input_path: '/readahead/200M/placer/prod/home/distbelief/imagenet-tensorflow/imagenet-2012-tfrecord/train*'
    is_training: true
    global_batch_size: 2048
    dtype: 'bfloat16'
    parser:
      mode: 'pretrain'
    decoder:
      decode_label: true
  validation_data:
    input_path: '/readahead/200M/placer/prod/home/distbelief/imagenet-tensorflow/imagenet-2012-tfrecord/valid*'
    is_training: false
    global_batch_size: 2048
    dtype: 'bfloat16'
    drop_remainder: false
    parser:
      mode: 'pretrain'
    decoder:
      decode_label: true
trainer:
  train_steps: 500000 # 800 epochs
  validation_steps: 24 # NUM_EXAMPLES (50000) // global_batch_size
  validation_interval: 625
  steps_per_loop: 625 # NUM_EXAMPLES (1281167) // global_batch_size
  summary_interval: 625
  checkpoint_interval: 625
  optimizer_config:
    optimizer:
      type: 'lars'
      lars:
        momentum: 0.9
        weight_decay_rate: 0.000001
        exclude_from_weight_decay: ['batch_normalization', 'bias']
    learning_rate:
      type: 'cosine'
      cosine:
        initial_learning_rate: 1.6 # 0.2 * BatchSize / 256
        decay_steps: 500000
    warmup:
      type: 'linear'
      linear:
        warmup_steps: 25000 # 5% of total epochs
DISCLAIMER: this YOLO implementation is still under development. No support will
be provided during the development phase.

# YOLO Object Detectors, You Only Look Once

[![Paper](http://img.shields.io/badge/Paper-arXiv.1804.02767-B3181B?logo=arXiv)](https://arxiv.org/abs/1804.02767)
...@@ -76,5 +79,4 @@ connected to a new, more powerful backbone if a person chose to.
[![Python 3.8](https://img.shields.io/badge/Python-3.8-3776AB)](https://www.python.org/downloads/release/python-380/)

DISCLAIMER: this YOLO implementation is still under development. No support will be provided during the development phase.
...@@ -19,3 +19,4 @@ from official.common import registry_imports
from official.vision.beta.projects.yolo.configs import darknet_classification
from official.vision.beta.projects.yolo.modeling.backbones import darknet
from official.vision.beta.projects.yolo.tasks import image_classification