Commit 78c43ef1 authored by Gunho Park

Merge branch 'master' of https://github.com/tensorflow/models

parents 67cfc95b e3c7e300
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A sample model implementation.
This is only a dummy example to showcase how a model is composed. It is usually
not necessary to implement a model from scratch. Most SoTA models can be found
in and used directly from the `official/vision/beta/modeling` directory.
"""
from typing import Any, Mapping
# Import libraries
import tensorflow as tf
from official.vision.beta.projects.example import example_config as example_cfg
@tf.keras.utils.register_keras_serializable(package='Vision')
class ExampleModel(tf.keras.Model):
"""A example model class.
A model is a subclass of tf.keras.Model where layers are built in the
constructor.
"""
def __init__(
self,
num_classes: int,
input_specs: tf.keras.layers.InputSpec = tf.keras.layers.InputSpec(
shape=[None, None, None, 3]),
**kwargs):
"""Initializes the example model.
All layers are defined in the constructor, and config is recorded in the
`_config_dict` object for serialization.
Args:
num_classes: The number of classes in the classification task.
input_specs: A `tf.keras.layers.InputSpec` spec of the input tensor.
**kwargs: Additional keyword arguments to be passed.
"""
inputs = tf.keras.Input(shape=input_specs.shape[1:], name=input_specs.name)
outputs = tf.keras.layers.Conv2D(
filters=16, kernel_size=3, strides=2, padding='same', use_bias=False)(
inputs)
outputs = tf.keras.layers.Conv2D(
filters=32, kernel_size=3, strides=2, padding='same', use_bias=False)(
outputs)
outputs = tf.keras.layers.Conv2D(
filters=64, kernel_size=3, strides=2, padding='same', use_bias=False)(
outputs)
outputs = tf.keras.layers.GlobalAveragePooling2D()(outputs)
outputs = tf.keras.layers.Dense(1024, activation='relu')(outputs)
outputs = tf.keras.layers.Dense(num_classes)(outputs)
super().__init__(inputs=inputs, outputs=outputs, **kwargs)
self._input_specs = input_specs
self._config_dict = {'num_classes': num_classes, 'input_specs': input_specs}
def get_config(self) -> Mapping[str, Any]:
"""Gets the config of this model."""
return self._config_dict
@classmethod
def from_config(cls, config, custom_objects=None):
"""Constructs an instance of this model from input config."""
return cls(**config)
def build_example_model(input_specs: tf.keras.layers.InputSpec,
model_config: example_cfg.ExampleModel,
**kwargs) -> tf.keras.Model:
"""Builds and returns the example model.
This function is the main entry point to build a model. Commonly, it builds a
model by assembling a backbone, decoder, and head. An example of building a
classification model is at
third_party/tensorflow_models/official/vision/beta/modeling/backbones/resnet.py.
However, it is not mandatory for all models to have these three pieces
exactly. Depending on the task, a model can be as simple as the example model
here or more complex, such as a multi-head architecture.
Args:
input_specs: The specs of the input layer that defines input size.
model_config: The config containing parameters to build a model.
**kwargs: Additional keyword arguments to be passed.
Returns:
A tf.keras.Model object.
"""
return ExampleModel(
num_classes=model_config.num_classes, input_specs=input_specs, **kwargs)
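
# Example usage (a minimal sketch, not part of the library): build the model
# directly and run a forward pass on dummy data.
#
#   model = ExampleModel(num_classes=10)
#   logits = model(tf.ones([2, 64, 64, 3]), training=False)
#   # logits.shape == (2, 10)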
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An example task definition for image classification."""
from typing import Any, List, Optional, Tuple, Sequence, Mapping
import tensorflow as tf
from official.common import dataset_fn
from official.core import base_task
from official.core import task_factory
from official.modeling import tf_utils
from official.vision.beta.dataloaders import input_reader_factory
from official.vision.beta.projects.example import example_config as exp_cfg
from official.vision.beta.projects.example import example_input
from official.vision.beta.projects.example import example_model
@task_factory.register_task_cls(exp_cfg.ExampleTask)
class ExampleTask(base_task.Task):
"""Class of an example task.
A task is a subclass of base_task.Task that defines the model, input, loss,
metrics, and the training and evaluation steps.
"""
def build_model(self) -> tf.keras.Model:
"""Builds a model."""
input_specs = tf.keras.layers.InputSpec(shape=[None] +
self.task_config.model.input_size)
model = example_model.build_example_model(
input_specs=input_specs, model_config=self.task_config.model)
return model
def build_inputs(
self,
params: exp_cfg.ExampleDataConfig,
input_context: Optional[tf.distribute.InputContext] = None
) -> tf.data.Dataset:
"""Builds input.
The input from this function is a tf.data.Dataset that has gone through
pre-processing steps, such as augmentation, batching, shuffling, etc.
Args:
params: The input data config.
input_context: An optional InputContext used by input reader.
Returns:
A tf.data.Dataset object.
"""
num_classes = self.task_config.model.num_classes
input_size = self.task_config.model.input_size
decoder = example_input.Decoder()
parser = example_input.Parser(
output_size=input_size[:2], num_classes=num_classes)
reader = input_reader_factory.input_reader_generator(
params,
dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
decoder_fn=decoder.decode,
parser_fn=parser.parse_fn(params.is_training))
dataset = reader.read(input_context=input_context)
return dataset
def build_losses(self,
labels: tf.Tensor,
model_outputs: tf.Tensor,
aux_losses: Optional[Any] = None) -> tf.Tensor:
"""Builds losses for training and validation.
Args:
labels: Input groundtruth labels.
model_outputs: Output of the model.
aux_losses: The auxiliary loss tensors, i.e. `losses` in tf.keras.Model.
Returns:
The total loss tensor.
"""
total_loss = tf.keras.losses.sparse_categorical_crossentropy(
labels, model_outputs, from_logits=True)
total_loss = tf_utils.safe_mean(total_loss)
if aux_losses:
total_loss += tf.add_n(aux_losses)
return total_loss
def build_metrics(self,
training: bool = True) -> Sequence[tf.keras.metrics.Metric]:
"""Gets streaming metrics for training/validation.
This function builds and returns a list of metrics to compute during
training and validation. The list contains objects of subclasses of
tf.keras.metrics.Metric. Training and validation can have different metrics.
Args:
training: Whether the metric is for training or not.
Returns:
A list of tf.keras.metrics.Metric objects.
"""
k = self.task_config.evaluation.top_k
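# For example, with top_k = 5 this produces metrics named 'accuracy' and
# 'top_5_accuracy'.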
metrics = [
tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
tf.keras.metrics.SparseTopKCategoricalAccuracy(
k=k, name='top_{}_accuracy'.format(k))
]
return metrics
def train_step(self,
inputs: Tuple[Any, Any],
model: tf.keras.Model,
optimizer: tf.keras.optimizers.Optimizer,
metrics: Optional[List[Any]] = None) -> Mapping[str, Any]:
"""Does forward and backward.
This example assumes the input is a tuple of (features, labels), which follows
the output of the data loader, i.e., the Parser. The output from the Parser is
fed into train_step to perform one forward and backward pass. Other data
structures, such as a dictionary, can also be used, as long as they are
consistent between the Parser output and the input used here.
Args:
inputs: A tuple of input tensors of (features, labels).
model: A tf.keras.Model instance.
optimizer: The optimizer for this training step.
metrics: A nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
features, labels = inputs
num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
with tf.GradientTape() as tape:
outputs = model(features, training=True)
# Casting the outputs to float32 is necessary when the mixed precision
# policy is mixed_float16 or mixed_bfloat16, to ensure the loss is computed
# in float32.
outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
# Computes per-replica loss.
loss = self.build_losses(
model_outputs=outputs, labels=labels, aux_losses=model.losses)
# Scales loss as the default gradients allreduce performs sum inside the
# optimizer.
scaled_loss = loss / num_replicas
# For mixed_precision policy, when LossScaleOptimizer is used, loss is
# scaled for numerical stability.
if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
scaled_loss = optimizer.get_scaled_loss(scaled_loss)
tvars = model.trainable_variables
grads = tape.gradient(scaled_loss, tvars)
# Scales back gradient before apply_gradients when LossScaleOptimizer is
# used.
if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
grads = optimizer.get_unscaled_gradients(grads)
optimizer.apply_gradients(list(zip(grads, tvars)))
logs = {self.loss: loss}
if metrics:
self.process_metrics(metrics, labels, outputs)
return logs
def validation_step(self,
inputs: Tuple[Any, Any],
model: tf.keras.Model,
metrics: Optional[List[Any]] = None) -> Mapping[str, Any]:
"""Runs validatation step.
Args:
inputs: A tuple of input tensors of (features, labels).
model: A tf.keras.Model instance.
metrics: A nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
features, labels = inputs
outputs = self.inference_step(features, model)
outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
loss = self.build_losses(
model_outputs=outputs, labels=labels, aux_losses=model.losses)
logs = {self.loss: loss}
if metrics:
self.process_metrics(metrics, labels, outputs)
return logs
def inference_step(self, inputs: tf.Tensor, model: tf.keras.Model) -> Any:
"""Performs the forward step. It is used in validation_step."""
return model(inputs, training=False)
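
# Typical lifecycle (a sketch; in practice the official.core training loop
# drives these calls rather than user code):
#
#   task = ExampleTask(task_config)
#   model = task.build_model()
#   dataset = task.build_inputs(task_config.train_data)
#   metrics = task.build_metrics(training=True)
#   logs = task.train_step(next(iter(dataset)), model, optimizer, metrics)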
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""All necessary imports for registration.
Custom models, tasks, configs, etc. need to be imported into the registry so
they can be picked up by the trainer. They can be included in this file so you
do not need to handle each file separately.
"""
# pylint: disable=unused-import
from official.common import registry_imports
from official.vision.beta.projects.example import example_config
from official.vision.beta.projects.example import example_input
from official.vision.beta.projects.example import example_model
from official.vision.beta.projects.example import example_task
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TensorFlow Model Garden Vision trainer.
All custom registries are imported from registry_imports. Here we use the
default trainer, so we directly call train.main. If you need to customize the
trainer, branch from `official/vision/beta/train.py` and make changes.
"""
from absl import app
from official.common import flags as tfm_flags
from official.vision.beta import train
from official.vision.beta.projects.example import registry_imports # pylint: disable=unused-import
if __name__ == '__main__':
tfm_flags.define_flags()
app.run(train.main)
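
# Example invocation (a sketch; the experiment name and paths below are
# placeholders, not values defined in this project):
#
#   python3 -m official.vision.beta.projects.example.train \
#     --experiment=<registered_experiment_name> \
#     --mode=train_and_eval \
#     --model_dir=/tmp/example_model \
#     --config_file=<path/to/config.yaml>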
......@@ -8,6 +8,8 @@ This repository is the official implementation of
[MoViNets: Mobile Video Networks for Efficient Video
Recognition](https://arxiv.org/abs/2103.11511).
**[UPDATE 2021-07-12] Mobile Models Available via [TF Lite](#tf-lite-streaming-models)**
<p align="center">
<img src="https://storage.googleapis.com/tf_model_garden/vision/movinet/artifacts/hoverboard_stream.gif" height=500>
</p>
......@@ -53,6 +55,8 @@ approach that performs redundant computation and limits temporal scope.
## History
- **2021-07-12** Add TF Lite support and replace 3D stream models with
mobile-friendly (2+1)D stream.
- **2021-05-30** Add streaming MoViNet checkpoints and examples.
- **2021-05-11** Initial Commit.
......@@ -68,6 +72,7 @@ approach that performs redundant computation and limits temporal scope.
- [Results and Pretrained Weights](#results-and-pretrained-weights)
- [Kinetics 600](#kinetics-600)
- [Prediction Examples](#prediction-examples)
- [TF Lite Example](#tf-lite-example)
- [Training and Evaluation](#training-and-evaluation)
- [References](#references)
- [License](#license)
......@@ -108,10 +113,14 @@ MoViNet-A5.
#### Base Models
Base models implement standard 3D convolutions without stream buffers.
Base models implement standard 3D convolutions without stream buffers. Base
models are not recommended for fast inference on CPU or mobile due to
limited support for
[`tf.nn.conv3d`](https://www.tensorflow.org/api_docs/python/tf/nn/conv3d).
Instead, see the [streaming models section](#streaming-models).
| Model Name | Top-1 Accuracy | Top-5 Accuracy | Input Shape | GFLOPs\* | Chekpoint | TF Hub SavedModel |
|------------|----------------|----------------|-------------|----------|-----------|-------------------|
| Model Name | Top-1 Accuracy | Top-5 Accuracy | Input Shape | GFLOPs\* | Checkpoint | TF Hub SavedModel |
|------------|----------------|----------------|-------------|----------|------------|-------------------|
| MoViNet-A0-Base | 72.28 | 90.92 | 50 x 172 x 172 | 2.7 | [checkpoint (12 MB)](https://storage.googleapis.com/tf_model_garden/vision/movinet/movinet_a0_base.tar.gz) | [tfhub](https://tfhub.dev/tensorflow/movinet/a0/base/kinetics-600/classification/) |
| MoViNet-A1-Base | 76.69 | 93.40 | 50 x 172 x 172 | 6.0 | [checkpoint (18 MB)](https://storage.googleapis.com/tf_model_garden/vision/movinet/movinet_a1_base.tar.gz) | [tfhub](https://tfhub.dev/tensorflow/movinet/a1/base/kinetics-600/classification/) |
| MoViNet-A2-Base | 78.62 | 94.17 | 50 x 224 x 224 | 10 | [checkpoint (20 MB)](https://storage.googleapis.com/tf_model_garden/vision/movinet/movinet_a2_base.tar.gz) | [tfhub](https://tfhub.dev/tensorflow/movinet/a2/base/kinetics-600/classification/) |
......@@ -123,10 +132,19 @@ Base models implement standard 3D convolutions without stream buffers.
#### Streaming Models
Streaming models implement causal 3D convolutions with stream buffers.
Streaming models implement causal (2+1)D convolutions with stream buffers.
Streaming models use (2+1)D convolution instead of 3D to utilize optimized
[`tf.nn.conv2d`](https://www.tensorflow.org/api_docs/python/tf/nn/conv2d)
operations, which offer fast inference on CPU. Streaming models can be run on
individual frames or on larger video clips like base models.
Note: A3, A4, and A5 models use a positional encoding in the squeeze-excitation
blocks, while A0, A1, and A2 do not. For the smaller models, accuracy is
unaffected without positional encoding, while for the larger models it is
significantly worse.
| Model Name | Top-1 Accuracy | Top-5 Accuracy | Input Shape\* | GFLOPs\*\* | Chekpoint | TF Hub SavedModel |
|------------|----------------|----------------|---------------|------------|-----------|-------------------|
| Model Name | Top-1 Accuracy | Top-5 Accuracy | Input Shape\* | GFLOPs\*\* | Checkpoint | TF Hub SavedModel |
|------------|----------------|----------------|---------------|------------|------------|-------------------|
| MoViNet-A0-Stream | 72.05 | 90.63 | 50 x 172 x 172 | 2.7 | [checkpoint (12 MB)](https://storage.googleapis.com/tf_model_garden/vision/movinet/movinet_a0_stream.tar.gz) | [tfhub](https://tfhub.dev/tensorflow/movinet/a0/stream/kinetics-600/classification/) |
| MoViNet-A1-Stream | 76.45 | 93.25 | 50 x 172 x 172 | 6.0 | [checkpoint (18 MB)](https://storage.googleapis.com/tf_model_garden/vision/movinet/movinet_a1_stream.tar.gz) | [tfhub](https://tfhub.dev/tensorflow/movinet/a1/stream/kinetics-600/classification/) |
| MoViNet-A2-Stream | 78.40 | 94.05 | 50 x 224 x 224 | 10 | [checkpoint (20 MB)](https://storage.googleapis.com/tf_model_garden/vision/movinet/movinet_a2_stream.tar.gz) | [tfhub](https://tfhub.dev/tensorflow/movinet/a2/stream/kinetics-600/classification/) |
......@@ -139,6 +157,35 @@ duration of the 10-second clip.
\*\*GFLOPs per video on Kinetics 600.
Note: current streaming model checkpoints have been updated with a slightly
different architecture. To download the old checkpoints, insert `_legacy` before
`.tar.gz` in the URL. E.g., `movinet_a0_stream_legacy.tar.gz`.
##### TF Lite Streaming Models
For convenience, we provide converted TF Lite models for inference on mobile
devices. See the [TF Lite Example](#tf-lite-example) to export and run your own
models.
For reference, MoViNet-A0-Stream runs with a similar latency to
[MobileNetV3-Large](https://tfhub.dev/google/imagenet/mobilenet_v3_large_100_224/classification/)
with +5% accuracy on Kinetics 600.
| Model Name | Input Shape | Pixel 4 Latency\* | x86 Latency\* | TF Lite Binary |
|------------|-------------|-------------------|---------------|----------------|
| MoViNet-A0-Stream | 1 x 1 x 172 x 172 | 22 ms | 16 ms | [TF Lite (13 MB)](https://storage.googleapis.com/tf_model_garden/vision/movinet/movinet_a0_stream.tflite) |
| MoViNet-A1-Stream | 1 x 1 x 172 x 172 | 42 ms | 33 ms | [TF Lite (45 MB)](https://storage.googleapis.com/tf_model_garden/vision/movinet/movinet_a1_stream.tflite) |
| MoViNet-A2-Stream | 1 x 1 x 224 x 224 | 200 ms | 66 ms | [TF Lite (53 MB)](https://storage.googleapis.com/tf_model_garden/vision/movinet/movinet_a2_stream.tflite) |
| MoViNet-A3-Stream | 1 x 1 x 256 x 256 | - | 120 ms | [TF Lite (73 MB)](https://storage.googleapis.com/tf_model_garden/vision/movinet/movinet_a3_stream.tflite) |
| MoViNet-A4-Stream | 1 x 1 x 290 x 290 | - | 300 ms | [TF Lite (101 MB)](https://storage.googleapis.com/tf_model_garden/vision/movinet/movinet_a4_stream.tflite) |
| MoViNet-A5-Stream | 1 x 1 x 320 x 320 | - | 450 ms | [TF Lite (153 MB)](https://storage.googleapis.com/tf_model_garden/vision/movinet/movinet_a5_stream.tflite) |
\*Single-frame latency measured with unaltered float32 operations on a
single CPU core. Observed latency may differ depending on hardware
configuration. Measured on a stock Pixel 4 (Android 11) and x86 Intel Xeon
W-2135 CPU.
## Prediction Examples
Please check out our [Colab Notebook](https://colab.research.google.com/github/tensorflow/models/tree/master/official/vision/beta/projects/movinet/movinet_tutorial.ipynb)
......@@ -146,7 +193,7 @@ to get started with MoViNets.
This section provides examples on how to run prediction.
For base models, run the following:
For **base models**, run the following:
```python
import tensorflow as tf
......@@ -181,7 +228,7 @@ output = model(inputs)
prediction = tf.argmax(output, -1)
```
For streaming models, run the following:
For **streaming models**, run the following:
```python
import tensorflow as tf
......@@ -189,20 +236,31 @@ import tensorflow as tf
from official.vision.beta.projects.movinet.modeling import movinet
from official.vision.beta.projects.movinet.modeling import movinet_model
model_id = 'a0'
use_positional_encoding = model_id in {'a3', 'a4', 'a5'}
# Create backbone and model.
backbone = movinet.Movinet(
model_id='a0',
model_id=model_id,
causal=True,
conv_type='2plus1d',
se_type='2plus3d',
activation='hard_swish',
gating_activation='hard_sigmoid',
use_positional_encoding=use_positional_encoding,
use_external_states=True,
)
model = movinet_model.MovinetClassifier(
backbone, num_classes=600, output_states=True)
backbone,
num_classes=600,
output_states=True)
# Create your example input here.
# Refer to the paper for recommended input shapes.
inputs = tf.ones([1, 8, 172, 172, 3])
# [Optional] Build the model and load a pretrained checkpoint
# [Optional] Build the model and load a pretrained checkpoint.
model.build(inputs.shape)
checkpoint_dir = '/path/to/checkpoint'
......@@ -237,23 +295,89 @@ non_streaming_output, _ = model({**init_states, 'image': inputs})
non_streaming_prediction = tf.argmax(non_streaming_output, -1)
```
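
Part of the streaming example above is collapsed in this diff. For reference,
here is a minimal sketch of the frame-by-frame streaming loop, reusing the
`model` and `inputs` defined above:

```python
# Create initial states sized for the input video.
init_states = model.init_states(tf.shape(inputs))

# Split the video into individual frames along the temporal axis.
frames = tf.split(inputs, inputs.shape[1], axis=1)

# Feed each frame together with the previous output states.
states = init_states
for frame in frames:
  output, states = model({**states, 'image': frame})
prediction = tf.argmax(output, -1)
```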
## TF Lite Example
This section outlines an example of how to export a model to run on mobile
devices with [TF Lite](https://www.tensorflow.org/lite).
First, convert to [TF SavedModel](https://www.tensorflow.org/guide/saved_model)
by running `export_saved_model.py`. For example, for `MoViNet-A0-Stream`, run:
```shell
python3 export_saved_model.py \
--model_id=a0 \
--causal=True \
--conv_type=2plus1d \
--se_type=2plus3d \
--activation=hard_swish \
--gating_activation=hard_sigmoid \
--use_positional_encoding=False \
--num_classes=600 \
--batch_size=1 \
--num_frames=1 \
--image_size=172 \
--bundle_input_init_states_fn=False \
--checkpoint_path=/path/to/checkpoint \
--export_path=/tmp/movinet_a0_stream
```
Then the SavedModel can be converted to TF Lite using the [`TFLiteConverter`](https://www.tensorflow.org/lite/convert):
```python
saved_model_dir = '/tmp/movinet_a0_stream'
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()
with open('/tmp/movinet_a0_stream.tflite', 'wb') as f:
f.write(tflite_model)
```
To run inference with TF Lite, use the
[tf.lite.Interpreter](https://www.tensorflow.org/lite/guide/inference#load_and_run_a_model_in_python)
Python API:
```python
# Create the interpreter and signature runner
interpreter = tf.lite.Interpreter('/tmp/movinet_a0_stream.tflite')
signature = interpreter.get_signature_runner()
# Extract state names and create the initial (zero) states
def state_name(name: str) -> str:
return name[len('serving_default_'):-len(':0')]
init_states = {
state_name(x['name']): tf.zeros(x['shape'], dtype=x['dtype'])
for x in interpreter.get_input_details()
}
del init_states['image']
# Insert your video clip here
video = tf.ones([1, 8, 172, 172, 3])
clips = tf.split(video, video.shape[1], axis=1)
# To run on a video, pass in one frame at a time
states = init_states
for clip in clips:
# Input shape: [1, 1, 172, 172, 3]
outputs = signature(**states, image=clip)
logits = outputs.pop('logits')
states = outputs
```
Follow the [official guide](https://www.tensorflow.org/lite/guide) to run a
model with TF Lite on your mobile device.
## Training and Evaluation
Run this command line for continuous training and evaluation.
```shell
MODE=train_and_eval # Can also be 'train'
MODE=train_and_eval # Can also be 'train' if using a separate evaluator job
CONFIG_FILE=official/vision/beta/projects/movinet/configs/yaml/movinet_a0_k600_8x8.yaml
python3 official/vision/beta/projects/movinet/train.py \
--experiment=movinet_kinetics600 \
--mode=${MODE} \
--model_dir=/tmp/movinet/ \
--config_file=${CONFIG_FILE} \
--params_override="" \
--gin_file="" \
--gin_params="" \
--tpu="" \
--tf_data_service=""
--model_dir=/tmp/movinet_a0_base/ \
--config_file=${CONFIG_FILE}
```
Run this command line for evaluation.
......@@ -264,13 +388,8 @@ CONFIG_FILE=official/vision/beta/projects/movinet/configs/yaml/movinet_a0_k600_8
python3 official/vision/beta/projects/movinet/train.py \
--experiment=movinet_kinetics600 \
--mode=${MODE} \
--model_dir=/tmp/movinet/ \
--config_file=${CONFIG_FILE} \
--params_override="" \
--gin_file="" \
--gin_params="" \
--tpu="" \
--tf_data_service=""
--model_dir=/tmp/movinet_a0_base/ \
--config_file=${CONFIG_FILE}
```
## License
......
......@@ -44,6 +44,13 @@ class Movinet(hyperparams.Config):
# 2plus1d: (2+1)D convolution with Conv2D (2D reshaping)
# 3d_2plus1d: (2+1)D convolution with Conv3D (no 2D reshaping)
conv_type: str = '3d'
# Choose from ['3d', '2d', '2plus3d']
# 3d: default 3D global average pooling.
# 2d: 2D global average pooling.
# 2plus3d: concatenation of 2D and 3D global average pooling.
se_type: str = '3d'
activation: str = 'swish'
gating_activation: str = 'sigmoid'
stochastic_depth_drop_rate: float = 0.2
use_external_states: bool = False
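
For illustration, a minimal sketch of setting these backbone options from
Python (the `configs.movinet` import path is an assumption):

```python
from official.vision.beta.projects.movinet.configs import movinet as movinet_configs

# Mobile-friendly backbone options matching the YAML configs below.
backbone_cfg = movinet_configs.Movinet(
    model_id='a0',
    conv_type='2plus1d',
    se_type='2plus3d',
    activation='hard_swish',
    gating_activation='hard_sigmoid')
```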
......@@ -123,6 +130,7 @@ class MovinetModel(video_classification.VideoClassificationModel):
norm_momentum=0.99,
norm_epsilon=1e-3,
use_sync_bn=True)
activation: str = 'swish'
output_states: bool = False
......
......@@ -15,6 +15,11 @@ task:
movinet:
model_id: 'a0'
causal: true
# Note: we train with '3d_2plus1d', but convert to '2plus1d' for inference
conv_type: '3d_2plus1d'
se_type: '2plus3d'
activation: 'hard_swish'
gating_activation: 'hard_sigmoid'
stochastic_depth_drop_rate: 0.2
norm_activation:
use_sync_bn: true
......
......@@ -15,6 +15,11 @@ task:
movinet:
model_id: 'a1'
causal: true
# Note: we train with '3d_2plus1d', but convert to '2plus1d' for inference
conv_type: '3d_2plus1d'
se_type: '2plus3d'
activation: 'hard_swish'
gating_activation: 'hard_sigmoid'
stochastic_depth_drop_rate: 0.2
norm_activation:
use_sync_bn: true
......
......@@ -15,10 +15,15 @@ task:
movinet:
model_id: 'a2'
causal: true
# Note: we train with '3d_2plus1d', but convert to '2plus1d' for inference
conv_type: '3d_2plus1d'
se_type: '2plus3d'
activation: 'hard_swish'
gating_activation: 'hard_sigmoid'
stochastic_depth_drop_rate: 0.2
norm_activation:
use_sync_bn: true
dropout_rate: 0.2
dropout_rate: 0.5
train_data:
name: kinetics600
variant_name: rgb
......
......@@ -15,6 +15,11 @@ task:
movinet:
model_id: 'a3'
causal: true
# Note: we train with '3d_2plus1d', but convert to '2plus1d' for inference
conv_type: '3d_2plus1d'
se_type: '2plus3d'
activation: 'hard_swish'
gating_activation: 'hard_sigmoid'
use_positional_encoding: true
stochastic_depth_drop_rate: 0.2
norm_activation:
......
......@@ -15,6 +15,11 @@ task:
movinet:
model_id: 'a4'
causal: true
# Note: we train with '3d_2plus1d', but convert to '2plus1d' for inference
conv_type: '3d_2plus1d'
se_type: '2plus3d'
activation: 'hard_swish'
gating_activation: 'hard_sigmoid'
use_positional_encoding: true
stochastic_depth_drop_rate: 0.2
norm_activation:
......
......@@ -15,6 +15,11 @@ task:
movinet:
model_id: 'a5'
causal: true
# Note: we train with '3d_2plus1d', but convert to '2plus1d' for inference
conv_type: '3d_2plus1d'
se_type: '2plus3d'
activation: 'hard_swish'
gating_activation: 'hard_sigmoid'
use_positional_encoding: true
stochastic_depth_drop_rate: 0.2
norm_activation:
......@@ -42,7 +47,8 @@ task:
validation_data:
name: kinetics600
feature_shape: !!python/tuple
- 120
# Evaluate on 115 frames instead of 120, as the model will run out of memory (OOM) on TPU
- 115
- 320
- 320
- 3
......
......@@ -15,6 +15,11 @@ task:
movinet:
model_id: 't0'
causal: true
# Note: we train with '3d_2plus1d', but convert to '2plus1d' for inference
conv_type: '3d_2plus1d'
se_type: '2plus3d'
activation: 'hard_swish'
gating_activation: 'hard_sigmoid'
stochastic_depth_drop_rate: 0.2
norm_activation:
use_sync_bn: true
......
......@@ -28,6 +28,26 @@ python3 export_saved_model.py \
--checkpoint_path=""
```
Export for TF Lite example (a single frame per step for streaming mode, at the
model's 172 x 172 input resolution):
```shell
python3 export_saved_model.py \
--model_id=a0 \
--causal=True \
--conv_type=2plus1d \
--se_type=2plus3d \
--activation=hard_swish \
--gating_activation=hard_sigmoid \
--use_positional_encoding=False \
--num_classes=600 \
--batch_size=1 \
--num_frames=1 \
--image_size=172 \
--bundle_input_init_states_fn=False \
--checkpoint_path=/path/to/checkpoint \
--export_path=/tmp/movinet_a0_stream
```
To use an exported saved_model, refer to export_saved_model_test.py.
"""
......@@ -53,6 +73,18 @@ flags.DEFINE_string(
'3x3 followed by 5x1 conv). 3d_2plus1d uses (2+1)D convolution with '
'Conv3D and no 2D reshaping (e.g., a 5x3x3 kernel becomes 1x3x3 '
'followed by 5x1x1 conv).')
flags.DEFINE_string(
'se_type', '3d',
'3d, 2d, or 2plus3d. 3d uses the default 3D spatiotemporal global average '
'pooling for squeeze excitation. 2d uses 2D spatial global average pooling '
'on each frame. 2plus3d concatenates both 3D and 2D global average '
'pooling.')
flags.DEFINE_string(
'activation', 'swish',
'The main activation to use across layers.')
flags.DEFINE_string(
'gating_activation', 'sigmoid',
'The gating activation to use in squeeze-excitation layers.')
flags.DEFINE_bool(
'use_positional_encoding', False,
'Whether to use positional encoding (only applied when causal=True).')
......@@ -67,6 +99,10 @@ flags.DEFINE_integer(
flags.DEFINE_integer(
'image_size', None,
'The resolution of the input. Set to None for dynamic input.')
flags.DEFINE_bool(
'bundle_input_init_states_fn', True,
'Add init_states as a function signature to the saved model. '
'This is not necessary if the input shape is static (e.g., for TF Lite).')
flags.DEFINE_string(
'checkpoint_path', '',
'Checkpoint path to load. Leave blank for default initialization.')
......@@ -85,21 +121,33 @@ def main(_) -> None:
# Use dimensions of 1 except the channels to export faster,
# since we only really need the last dimension to build and get the output
# states. These dimensions will be set to `None` once the model is built.
# states. These dimensions can be set to `None` once the model is built.
input_shape = [1 if s is None else s for s in input_specs.shape]
activation = FLAGS.activation
if activation == 'swish':
# Override swish activation implementation to remove custom gradients
activation = 'simple_swish'
backbone = movinet.Movinet(
FLAGS.model_id,
model_id=FLAGS.model_id,
causal=FLAGS.causal,
use_positional_encoding=FLAGS.use_positional_encoding,
conv_type=FLAGS.conv_type,
use_external_states=FLAGS.causal,
se_type=FLAGS.se_type,
input_specs=input_specs,
use_positional_encoding=FLAGS.use_positional_encoding)
activation=activation,
gating_activation=FLAGS.gating_activation,
use_sync_bn=False,
use_external_states=FLAGS.causal)
model = movinet_model.MovinetClassifier(
backbone,
num_classes=FLAGS.num_classes,
output_states=FLAGS.causal,
input_specs=dict(image=input_specs))
input_specs=dict(image=input_specs),
# TODO(dankondratyuk): currently set to swish, but will need to
# re-train to use other activations.
activation='simple_swish')
model.build(input_shape)
# Compile model to generate some internal Keras variables.
......@@ -116,7 +164,7 @@ def main(_) -> None:
# with the full output state shapes.
input_image = tf.ones(input_shape)
_, states = model({**model.init_states(input_shape), 'image': input_image})
_, states = model({**states, 'image': input_image})
_ = model({**states, 'image': input_image})
# Create a function to explicitly set the names of the outputs
def predict(inputs):
......@@ -138,7 +186,10 @@ def main(_) -> None:
init_states_fn = init_states_fn.get_concrete_function(
tf.TensorSpec([5], dtype=tf.int32))
signatures = {'call': predict_fn, 'init_states': init_states_fn}
if FLAGS.bundle_input_init_states_fn:
signatures = {'call': predict_fn, 'init_states': init_states_fn}
else:
signatures = predict_fn
tf.keras.models.save_model(
model, FLAGS.export_path, signatures=signatures)
......
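
For reference, a quick sketch of inspecting an exported model's signatures
(the export path is a placeholder):

```python
import tensorflow as tf

loaded = tf.saved_model.load('/tmp/movinet_a0_stream')
# With --bundle_input_init_states_fn=True this contains 'call' and
# 'init_states'; otherwise a single default serving signature.
print(list(loaded.signatures.keys()))
```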
......@@ -48,7 +48,7 @@ class ExportSavedModelTest(tf.test.TestCase):
example_input = tf.ones([1, 8, 172, 172, 3])
outputs = model(example_input)
self.assertEqual(outputs.shape, [1, 600])
self.assertAllEqual(outputs.shape, [1, 600])
def test_movinet_export_a0_stream_with_tfhub(self):
saved_model_path = self.get_temp_dir()
......@@ -94,9 +94,55 @@ class ExportSavedModelTest(tf.test.TestCase):
for frame in frames:
outputs, states = model({**states, 'image': frame})
self.assertEqual(outputs.shape, [1, 600])
self.assertAllEqual(outputs.shape, [1, 600])
self.assertNotEmpty(states)
self.assertAllClose(outputs, expected_outputs, 1e-5, 1e-5)
def test_movinet_export_a0_stream_with_tflite(self):
saved_model_path = self.get_temp_dir()
FLAGS.export_path = saved_model_path
FLAGS.model_id = 'a0'
FLAGS.causal = True
FLAGS.conv_type = '2plus1d'
FLAGS.se_type = '2plus3d'
FLAGS.activation = 'hard_swish'
FLAGS.gating_activation = 'hard_sigmoid'
FLAGS.use_positional_encoding = False
FLAGS.num_classes = 600
FLAGS.batch_size = 1
FLAGS.num_frames = 1
FLAGS.image_size = 172
FLAGS.bundle_input_init_states_fn = False
export_saved_model.main('unused_args')
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_path)
tflite_model = converter.convert()
interpreter = tf.lite.Interpreter(model_content=tflite_model)
signature = interpreter.get_signature_runner()
def state_name(name: str) -> str:
return name[len('serving_default_'):-len(':0')]
init_states = {
state_name(x['name']): tf.zeros(x['shape'], dtype=x['dtype'])
for x in interpreter.get_input_details()
}
del init_states['image']
video = tf.ones([1, 8, 172, 172, 3])
clips = tf.split(video, video.shape[1], axis=1)
states = init_states
for clip in clips:
outputs = signature(**states, image=clip)
logits = outputs.pop('logits')
states = outputs
self.assertAllEqual(logits.shape, [1, 600])
self.assertNotEmpty(states)
if __name__ == '__main__':
tf.test.main()
......@@ -43,6 +43,9 @@ S12: KernelSize = (1, 2, 2)
S22: KernelSize = (2, 2, 2)
S21: KernelSize = (2, 1, 1)
# Type for a state container (map)
TensorMap = Mapping[str, tf.Tensor]
@dataclasses.dataclass
class BlockSpec:
......@@ -307,8 +310,10 @@ class Movinet(tf.keras.Model):
causal: bool = False,
use_positional_encoding: bool = False,
conv_type: str = '3d',
se_type: str = '3d',
input_specs: Optional[tf.keras.layers.InputSpec] = None,
activation: str = 'swish',
gating_activation: str = 'sigmoid',
use_sync_bn: bool = True,
norm_momentum: float = 0.99,
norm_epsilon: float = 0.001,
......@@ -317,6 +322,7 @@ class Movinet(tf.keras.Model):
bias_regularizer: Optional[str] = None,
stochastic_depth_drop_rate: float = 0.,
use_external_states: bool = False,
output_states: bool = True,
**kwargs):
"""MoViNet initialization function.
......@@ -332,8 +338,13 @@ class Movinet(tf.keras.Model):
3x3 followed by 5x1 conv). '3d_2plus1d' uses (2+1)D convolution with
Conv3D and no 2D reshaping (e.g., a 5x3x3 kernel becomes 1x3x3 followed
by 5x1x1 conv).
se_type: '3d', '2d', or '2plus3d'. '3d' uses the default 3D
spatiotemporal global average pooling for squeeze excitation. '2d'
uses 2D spatial global average pooling on each frame. '2plus3d'
concatenates both 3D and 2D global average pooling.
input_specs: the model input spec to use.
activation: name of the activation function.
activation: name of the main activation function.
gating_activation: gating activation to use in squeeze excitation layers.
use_sync_bn: if True, use synchronized batch normalization.
norm_momentum: normalization momentum for the moving average.
norm_epsilon: small float added to variance to avoid dividing by
......@@ -346,6 +357,10 @@ class Movinet(tf.keras.Model):
stochastic_depth_drop_rate: the base rate for stochastic depth.
use_external_states: if True, expects states to be passed as additional
input.
output_states: if True, output intermediate states that can be used to run
the model in streaming mode. Inputting the output states of the
previous input clip with the current input clip will utilize a stream
buffer for streaming video.
**kwargs: keyword arguments to be passed.
"""
block_specs = BLOCK_SPECS[model_id]
......@@ -354,15 +369,19 @@ class Movinet(tf.keras.Model):
if conv_type not in ('3d', '2plus1d', '3d_2plus1d'):
raise ValueError('Unknown conv type: {}'.format(conv_type))
if se_type not in ('3d', '2d', '2plus3d'):
raise ValueError('Unknown squeeze excitation type: {}'.format(se_type))
self._model_id = model_id
self._block_specs = block_specs
self._causal = causal
self._use_positional_encoding = use_positional_encoding
self._conv_type = conv_type
self._se_type = se_type
self._input_specs = input_specs
self._use_sync_bn = use_sync_bn
self._activation = activation
self._gating_activation = gating_activation
self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon
if use_sync_bn:
......@@ -374,6 +393,7 @@ class Movinet(tf.keras.Model):
self._bias_regularizer = bias_regularizer
self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
self._use_external_states = use_external_states
self._output_states = output_states
if self._use_external_states and not self._causal:
raise ValueError('External states should be used with causal mode.')
......@@ -400,8 +420,7 @@ class Movinet(tf.keras.Model):
self,
input_specs: tf.keras.layers.InputSpec,
state_specs: Optional[Mapping[str, tf.keras.layers.InputSpec]] = None,
) -> Tuple[Mapping[str, tf.keras.Input], Tuple[Mapping[str, tf.Tensor],
Mapping[str, tf.Tensor]]]:
) -> Tuple[TensorMap, Union[TensorMap, Tuple[TensorMap, TensorMap]]]:
"""Builds the model network.
Args:
......@@ -412,7 +431,7 @@ class Movinet(tf.keras.Model):
Returns:
Inputs and outputs as a tuple. Inputs are expected to be a dict with
base input and states. Outputs are expected to be a dict of endpoints
and output states.
and (optional) output states.
"""
state_specs = state_specs if state_specs is not None else {}
......@@ -475,10 +494,12 @@ class Movinet(tf.keras.Model):
strides=strides,
causal=self._causal,
activation=self._activation,
gating_activation=self._gating_activation,
stochastic_depth_drop_rate=stochastic_depth_drop_rate,
conv_type=self._conv_type,
use_positional_encoding=self._use_positional_encoding and
self._causal,
se_type=self._se_type,
use_positional_encoding=
self._use_positional_encoding and self._causal,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
batch_norm_layer=self._norm,
......@@ -506,7 +527,7 @@ class Movinet(tf.keras.Model):
else:
raise ValueError('Unknown block type {}'.format(block))
outputs = (endpoints, states)
outputs = (endpoints, states) if self._output_states else endpoints
return inputs, outputs
......@@ -666,6 +687,8 @@ class Movinet(tf.keras.Model):
'kernel_regularizer': self._kernel_regularizer,
'bias_regularizer': self._bias_regularizer,
'stochastic_depth_drop_rate': self._stochastic_depth_drop_rate,
'use_external_states': self._use_external_states,
'output_states': self._output_states,
}
return config_dict
......@@ -691,8 +714,10 @@ def build_movinet(
causal=backbone_cfg.causal,
use_positional_encoding=backbone_cfg.use_positional_encoding,
conv_type=backbone_cfg.conv_type,
se_type=backbone_cfg.se_type,
input_specs=input_specs,
activation=norm_activation_config.activation,
activation=backbone_cfg.activation,
gating_activation=backbone_cfg.gating_activation,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
......
......@@ -265,7 +265,7 @@ class ConvBlock(tf.keras.layers.Layer):
tf.keras.regularizers.L2(KERNEL_WEIGHT_DECAY),
use_batch_norm: bool = True,
batch_norm_layer: tf.keras.layers.Layer =
tf.keras.layers.experimental.SyncBatchNormalization,
tf.keras.layers.BatchNormalization,
batch_norm_momentum: float = 0.99,
batch_norm_epsilon: float = 1e-3,
activation: Optional[Any] = None,
......@@ -547,8 +547,8 @@ class StreamConvBlock(ConvBlock):
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = tf.keras
.regularizers.L2(KERNEL_WEIGHT_DECAY),
use_batch_norm: bool = True,
batch_norm_layer: tf.keras.layers.Layer = tf.keras.layers.experimental
.SyncBatchNormalization,
batch_norm_layer: tf.keras.layers.Layer =
tf.keras.layers.BatchNormalization,
batch_norm_momentum: float = 0.99,
batch_norm_epsilon: float = 1e-3,
activation: Optional[Any] = None,
......@@ -669,6 +669,7 @@ class StreamSqueezeExcitation(tf.keras.layers.Layer):
def __init__(
self,
hidden_filters: int,
se_type: str = '3d',
activation: nn_layers.Activation = 'swish',
gating_activation: nn_layers.Activation = 'sigmoid',
causal: bool = False,
......@@ -683,6 +684,10 @@ class StreamSqueezeExcitation(tf.keras.layers.Layer):
Args:
hidden_filters: The hidden filters of squeeze excite.
se_type: '3d', '2d', or '2plus3d'. '3d' uses the default 3D
spatiotemporal global average pooling for squeeze excitation. '2d'
uses 2D spatial global average pooling on each frame. '2plus3d'
concatenates both 3D and 2D global average pooling.
activation: name of the activation function.
gating_activation: name of the activation function for gating.
causal: if True, use causal mode in the global average pool.
......@@ -700,6 +705,7 @@ class StreamSqueezeExcitation(tf.keras.layers.Layer):
super(StreamSqueezeExcitation, self).__init__(**kwargs)
self._hidden_filters = hidden_filters
self._se_type = se_type
self._activation = activation
self._gating_activation = gating_activation
self._causal = causal
......@@ -709,8 +715,9 @@ class StreamSqueezeExcitation(tf.keras.layers.Layer):
self._use_positional_encoding = use_positional_encoding
self._state_prefix = state_prefix
self._pool = nn_layers.GlobalAveragePool3D(
self._spatiotemporal_pool = nn_layers.GlobalAveragePool3D(
keepdims=True, causal=causal, state_prefix=state_prefix)
self._spatial_pool = nn_layers.SpatialAveragePool3D(keepdims=True)
self._pos_encoding = None
if use_positional_encoding:
......@@ -721,6 +728,7 @@ class StreamSqueezeExcitation(tf.keras.layers.Layer):
"""Returns a dictionary containing the config used for initialization."""
config = {
'hidden_filters': self._hidden_filters,
'se_type': self._se_type,
'activation': self._activation,
'gating_activation': self._gating_activation,
'causal': self._causal,
......@@ -777,13 +785,28 @@ class StreamSqueezeExcitation(tf.keras.layers.Layer):
"""
states = dict(states) if states is not None else {}
x, states = self._pool(inputs, states=states)
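# Pool according to se_type: '3d' pools over space and time (optionally
# causal), '2d' pools each frame spatially, and '2plus3d' concatenates
# clip-level and per-frame pooled features.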
if self._se_type == '3d':
x, states = self._spatiotemporal_pool(inputs, states=states)
elif self._se_type == '2d':
x = self._spatial_pool(inputs)
elif self._se_type == '2plus3d':
x_space = self._spatial_pool(inputs)
x, states = self._spatiotemporal_pool(x_space, states=states)
if not self._causal:
x = tf.tile(x, [1, tf.shape(inputs)[1], 1, 1, 1])
x = tf.concat([x, x_space], axis=-1)
else:
raise ValueError('Unknown Squeeze Excitation type {}'.format(
self._se_type))
if self._pos_encoding is not None:
x, states = self._pos_encoding(x, states=states)
x = self._se_reduce(x)
x = self._se_expand(x)
return x * inputs, states
......@@ -892,7 +915,7 @@ class SkipBlock(tf.keras.layers.Layer):
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] =
tf.keras.regularizers.L2(KERNEL_WEIGHT_DECAY),
batch_norm_layer: tf.keras.layers.Layer =
tf.keras.layers.experimental.SyncBatchNormalization,
tf.keras.layers.BatchNormalization,
batch_norm_momentum: float = 0.99,
batch_norm_epsilon: float = 1e-3,
**kwargs):
......@@ -999,15 +1022,17 @@ class MovinetBlock(tf.keras.layers.Layer):
strides: Union[int, Sequence[int]] = (1, 1, 1),
causal: bool = False,
activation: nn_layers.Activation = 'swish',
gating_activation: nn_layers.Activation = 'sigmoid',
se_ratio: float = 0.25,
stochastic_depth_drop_rate: float = 0.,
conv_type: str = '3d',
se_type: str = '3d',
use_positional_encoding: bool = False,
kernel_initializer: tf.keras.initializers.Initializer = 'HeNormal',
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = tf.keras
.regularizers.L2(KERNEL_WEIGHT_DECAY),
batch_norm_layer: tf.keras.layers.Layer = tf.keras.layers.experimental
.SyncBatchNormalization,
batch_norm_layer: tf.keras.layers.Layer =
tf.keras.layers.BatchNormalization,
batch_norm_momentum: float = 0.99,
batch_norm_epsilon: float = 1e-3,
state_prefix: Optional[str] = None,
......@@ -1021,12 +1046,17 @@ class MovinetBlock(tf.keras.layers.Layer):
strides: strides of the main depthwise convolution.
causal: if True, run the temporal convolutions in causal mode.
activation: activation to use across all conv operations.
gating_activation: gating activation to use in squeeze excitation layers.
se_ratio: squeeze excite filters ratio.
stochastic_depth_drop_rate: optional drop rate for stochastic depth.
conv_type: '3d', '2plus1d', or '3d_2plus1d'. '3d' uses the default 3D
ops. '2plus1d' splits any 3D ops into two sequential 2D ops with their
own batch norm and activation. '3d_2plus1d' is like '2plus1d', but
uses two sequential 3D ops instead.
se_type: '3d', '2d', or '2plus3d'. '3d' uses the default 3D
spatiotemporal global average pooling for squeeze excitation. '2d'
uses 2D spatial global average pooling on each frame. '2plus3d'
concatenates both 3D and 2D global average pooling.
use_positional_encoding: add a positional encoding after the (cumulative)
global average pooling layer in the squeeze excite layer.
kernel_initializer: kernel initializer for the conv operations.
......@@ -1042,17 +1072,20 @@ class MovinetBlock(tf.keras.layers.Layer):
self._kernel_size = normalize_tuple(kernel_size, 3, 'kernel_size')
self._strides = normalize_tuple(strides, 3, 'strides')
# Use a multiplier of 2 if concatenating multiple features
se_multiplier = 2 if se_type == '2plus3d' else 1
se_hidden_filters = nn_layers.make_divisible(
se_ratio * expand_filters, divisor=8)
se_ratio * expand_filters * se_multiplier, divisor=8)
self._out_filters = out_filters
self._expand_filters = expand_filters
self._kernel_size = kernel_size
self._causal = causal
self._activation = activation
self._gating_activation = gating_activation
self._se_ratio = se_ratio
self._downsample = any(s > 1 for s in self._strides)
self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
self._conv_type = conv_type
self._se_type = se_type
self._use_positional_encoding = use_positional_encoding
self._kernel_initializer = kernel_initializer
self._kernel_regularizer = kernel_regularizer
......@@ -1103,7 +1136,9 @@ class MovinetBlock(tf.keras.layers.Layer):
name='projection')
self._attention = StreamSqueezeExcitation(
se_hidden_filters,
se_type=se_type,
activation=activation,
gating_activation=gating_activation,
causal=self._causal,
conv_type=conv_type,
use_positional_encoding=use_positional_encoding,
......@@ -1121,9 +1156,11 @@ class MovinetBlock(tf.keras.layers.Layer):
'strides': self._strides,
'causal': self._causal,
'activation': self._activation,
'gating_activation': self._gating_activation,
'se_ratio': self._se_ratio,
'stochastic_depth_drop_rate': self._stochastic_depth_drop_rate,
'conv_type': self._conv_type,
'se_type': self._se_type,
'use_positional_encoding': self._use_positional_encoding,
'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer,
......@@ -1194,8 +1231,8 @@ class Stem(tf.keras.layers.Layer):
kernel_initializer: tf.keras.initializers.Initializer = 'HeNormal',
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = tf.keras
.regularizers.L2(KERNEL_WEIGHT_DECAY),
batch_norm_layer: tf.keras.layers.Layer = tf.keras.layers.experimental
.SyncBatchNormalization,
batch_norm_layer: tf.keras.layers.Layer =
tf.keras.layers.BatchNormalization,
batch_norm_momentum: float = 0.99,
batch_norm_epsilon: float = 1e-3,
state_prefix: Optional[str] = None,
......@@ -1302,8 +1339,8 @@ class Head(tf.keras.layers.Layer):
kernel_initializer: tf.keras.initializers.Initializer = 'HeNormal',
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = tf.keras
.regularizers.L2(KERNEL_WEIGHT_DECAY),
batch_norm_layer: tf.keras.layers.Layer = tf.keras.layers.experimental
.SyncBatchNormalization,
batch_norm_layer: tf.keras.layers.Layer =
tf.keras.layers.BatchNormalization,
batch_norm_momentum: float = 0.99,
batch_norm_epsilon: float = 1e-3,
state_prefix: Optional[str] = None,
......@@ -1432,6 +1469,7 @@ class ClassifierHead(tf.keras.layers.Layer):
self._num_classes = num_classes
self._dropout_rate = dropout_rate
self._conv_type = conv_type
self._activation = activation
self._output_activation = output_activation
self._max_pool_predictions = max_pool_predictions
self._kernel_initializer = kernel_initializer
......@@ -1471,6 +1509,7 @@ class ClassifierHead(tf.keras.layers.Layer):
'num_classes': self._num_classes,
'dropout_rate': self._dropout_rate,
'conv_type': self._conv_type,
'activation': self._activation,
'output_activation': self._output_activation,
'max_pool_predictions': self._max_pool_predictions,
'kernel_initializer': self._kernel_initializer,
......
......@@ -314,6 +314,43 @@ class MovinetLayersTest(parameterized.TestCase, tf.test.TestCase):
[[4., 4., 4.]]]]],
1e-5, 1e-5)
def test_stream_squeeze_excitation_2plus3d(self):
se = movinet_layers.StreamSqueezeExcitation(
3,
se_type='2plus3d',
causal=True,
activation='hard_swish',
gating_activation='hard_sigmoid',
kernel_initializer='ones')
inputs = tf.range(4, dtype=tf.float32) + 1.
inputs = tf.reshape(inputs, [1, 4, 1, 1, 1])
inputs = tf.tile(inputs, [1, 1, 2, 1, 3])
expected, _ = se(inputs)
for num_splits in [1, 2, 4]:
frames = tf.split(inputs, inputs.shape[1] // num_splits, axis=1)
states = {}
predicted = []
for frame in frames:
x, states = se(frame, states=states)
predicted.append(x)
predicted = tf.concat(predicted, axis=1)
self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected)
self.assertAllClose(
predicted,
[[[[[1., 1., 1.]],
[[1., 1., 1.]]],
[[[2., 2., 2.]],
[[2., 2., 2.]]],
[[[3., 3., 3.]],
[[3., 3., 3.]]],
[[[4., 4., 4.]],
[[4., 4., 4.]]]]])
def test_stream_movinet_block(self):
block = movinet_layers.MovinetBlock(
out_filters=3,
......
......@@ -36,6 +36,7 @@ class MovinetClassifier(tf.keras.Model):
backbone: tf.keras.Model,
num_classes: int,
input_specs: Optional[Mapping[str, tf.keras.layers.InputSpec]] = None,
activation: str = 'swish',
dropout_rate: float = 0.0,
kernel_initializer: str = 'HeNormal',
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
......@@ -48,6 +49,7 @@ class MovinetClassifier(tf.keras.Model):
backbone: A 3d backbone network.
num_classes: Number of classes in classification task.
input_specs: Specs of the input tensor.
activation: name of the main activation function.
dropout_rate: Rate for dropout regularization.
kernel_initializer: Kernel initializer for the final dense layer.
kernel_regularizer: Kernel regularizer.
......@@ -65,6 +67,7 @@ class MovinetClassifier(tf.keras.Model):
self._num_classes = num_classes
self._input_specs = input_specs
self._activation = activation
self._dropout_rate = dropout_rate
self._kernel_initializer = kernel_initializer
self._kernel_regularizer = kernel_regularizer
......@@ -151,7 +154,8 @@ class MovinetClassifier(tf.keras.Model):
dropout_rate=self._dropout_rate,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
conv_type=backbone.conv_type)(
conv_type=backbone.conv_type,
activation=self._activation)(
x)
outputs = (x, states) if self._output_states else x
......@@ -180,6 +184,7 @@ class MovinetClassifier(tf.keras.Model):
def get_config(self):
config = {
'backbone': self._backbone,
'activation': self._activation,
'num_classes': self._num_classes,
'input_specs': self._input_specs,
'dropout_rate': self._dropout_rate,
......@@ -226,6 +231,7 @@ def build_movinet_model(
num_classes=num_classes,
kernel_regularizer=l2_regularizer,
input_specs=input_specs_dict,
activation=model_config.activation,
dropout_rate=model_config.dropout_rate,
output_states=model_config.output_states)
......
......@@ -131,6 +131,37 @@ class MovinetModelTest(parameterized.TestCase, tf.test.TestCase):
self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected, 1e-5, 1e-5)
def test_movinet_classifier_mobile(self):
"""Test if the model can run with mobile parameters."""
tf.keras.backend.set_image_data_format('channels_last')
backbone = movinet.Movinet(
model_id='a0',
causal=True,
use_external_states=True,
conv_type='2plus1d',
se_type='2plus3d',
activation='hard_swish',
gating_activation='hard_sigmoid'
)
model = movinet_model.MovinetClassifier(
backbone, num_classes=600, output_states=True)
inputs = tf.ones([1, 8, 172, 172, 3])
init_states = model.init_states(tf.shape(inputs))
expected, _ = model({**init_states, 'image': inputs})
frames = tf.split(inputs, inputs.shape[1], axis=1)
states = init_states
for frame in frames:
output, states = model({**states, 'image': frame})
predicted = output
self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected, 1e-5, 1e-5)
def test_serialize_deserialize(self):
"""Validate the classification network can be serialized and deserialized."""
......