Commit 08f9393a authored by Dan Kondratyuk, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 386105413
parent ab4c1d07
@@ -26,10 +26,6 @@ from official.modeling import tf_utils
 States = Dict[str, tf.Tensor]
 Activation = Union[str, Callable]

-# TODO(dankondratyuk): keep legacy padding until new checkpoints are trained.
-# Otherwise, accuracy will be affected.
-LEGACY_PADDING = True

 def make_divisible(value: float,
                    divisor: int,
@@ -89,6 +85,22 @@ def hard_swish(x: tf.Tensor) -> tf.Tensor:
 tf.keras.utils.get_custom_objects().update({'hard_swish': hard_swish})
+def simple_swish(x: tf.Tensor) -> tf.Tensor:
+  """A swish/silu activation function without custom gradients.
+
+  Useful for exporting to SavedModel to avoid custom gradient warnings.
+
+  Args:
+    x: the input tensor.
+
+  Returns:
+    The activation output.
+  """
+  return x * tf.math.sigmoid(x)
+
+tf.keras.utils.get_custom_objects().update({'simple_swish': simple_swish})

 @tf.keras.utils.register_keras_serializable(package='Vision')
 class SqueezeExcitation(tf.keras.layers.Layer):
   """Creates a squeeze and excitation layer."""
@@ -752,14 +764,10 @@ class CausalConvMixin:
         (self.kernel_size[i] - 1) * (self.dilation_rate[i] - 1))
         for i in range(self.rank)
     ]
-    if LEGACY_PADDING:
-      # Apply legacy padding that does not take into account spatial strides
-      pad_total = [kernel_size_effective[i] - 1 for i in range(self.rank)]
-    else:
-      pad_total = [kernel_size_effective[0] - 1]
-      for i in range(1, self.rank):
-        overlap = (input_shape[i] - 1) % self.strides[i] + 1
-        pad_total.append(tf.maximum(kernel_size_effective[i] - overlap, 0))
+    pad_total = [kernel_size_effective[0] - 1]
+    for i in range(1, self.rank):
+      overlap = (input_shape[i] - 1) % self.strides[i] + 1
+      pad_total.append(tf.maximum(kernel_size_effective[i] - overlap, 0))
     pad_beg = [pad_total[i] // 2 for i in range(self.rank)]
     pad_end = [pad_total[i] - pad_beg[i] for i in range(self.rank)]
     padding = [[pad_beg[i], pad_end[i]] for i in range(self.rank)]
...
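For intuition, the stride-aware padding above can be checked with plain numbers. A standalone sketch (not part of the commit), assuming a single spatial dimension:

```python
# Standalone check of the stride-aware causal padding computed above.
def causal_pad_total(input_len: int, kernel_size: int, stride: int,
                     dilation: int = 1) -> int:
  kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
  # Only the part of the kernel not covered by the last stride window needs
  # padding; the removed legacy path always padded kernel_size_effective - 1.
  overlap = (input_len - 1) % stride + 1
  return max(kernel_size_effective - overlap, 0)

assert causal_pad_total(input_len=8, kernel_size=3, stride=1) == 2
assert causal_pad_total(input_len=8, kernel_size=3, stride=2) == 1
```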
@@ -24,10 +24,6 @@ from official.vision.beta.modeling.layers import nn_layers
 class NNLayersTest(parameterized.TestCase, tf.test.TestCase):

-  def setUp(self):
-    super().setUp()
-    nn_layers.LEGACY_PADDING = False

   def test_hard_swish(self):
     activation = tf.keras.layers.Activation('hard_swish')
     output = activation(tf.constant([-3, -1.5, 0, 3]))
...
@@ -8,6 +8,8 @@ This repository is the official implementation of
 [MoViNets: Mobile Video Networks for Efficient Video
 Recognition](https://arxiv.org/abs/2103.11511).

+**[UPDATE 2021-07-12] Mobile Models Available via [TF Lite](#tf-lite-streaming-models)**

 <p align="center">
   <img src="https://storage.googleapis.com/tf_model_garden/vision/movinet/artifacts/hoverboard_stream.gif" height=500>
 </p>
@@ -53,6 +55,8 @@ approach that performs redundant computation and limits temporal scope.
 ## History

+- **2021-07-12** Add TF Lite support and replace 3D stream models with
+  mobile-friendly (2+1)D stream.
 - **2021-05-30** Add streaming MoViNet checkpoints and examples.
 - **2021-05-11** Initial Commit.
@@ -68,6 +72,7 @@ approach that performs redundant computation and limits temporal scope.
 - [Results and Pretrained Weights](#results-and-pretrained-weights)
   - [Kinetics 600](#kinetics-600)
 - [Prediction Examples](#prediction-examples)
+- [TF Lite Example](#tf-lite-example)
 - [Training and Evaluation](#training-and-evaluation)
 - [References](#references)
 - [License](#license)
@@ -108,10 +113,14 @@ MoViNet-A5.
 #### Base Models

-Base models implement standard 3D convolutions without stream buffers.
+Base models implement standard 3D convolutions without stream buffers. Base
+models are not recommended for fast inference on CPU or mobile due to
+limited support for
+[`tf.nn.conv3d`](https://www.tensorflow.org/api_docs/python/tf/nn/conv3d).
+Instead, see the [streaming models section](#streaming-models).

-| Model Name | Top-1 Accuracy | Top-5 Accuracy | Input Shape | GFLOPs\* | Chekpoint | TF Hub SavedModel |
-|------------|----------------|----------------|-------------|----------|-----------|-------------------|
+| Model Name | Top-1 Accuracy | Top-5 Accuracy | Input Shape | GFLOPs\* | Checkpoint | TF Hub SavedModel |
+|------------|----------------|----------------|-------------|----------|------------|-------------------|
 | MoViNet-A0-Base | 72.28 | 90.92 | 50 x 172 x 172 | 2.7 | [checkpoint (12 MB)](https://storage.googleapis.com/tf_model_garden/vision/movinet/movinet_a0_base.tar.gz) | [tfhub](https://tfhub.dev/tensorflow/movinet/a0/base/kinetics-600/classification/) |
 | MoViNet-A1-Base | 76.69 | 93.40 | 50 x 172 x 172 | 6.0 | [checkpoint (18 MB)](https://storage.googleapis.com/tf_model_garden/vision/movinet/movinet_a1_base.tar.gz) | [tfhub](https://tfhub.dev/tensorflow/movinet/a1/base/kinetics-600/classification/) |
 | MoViNet-A2-Base | 78.62 | 94.17 | 50 x 224 x 224 | 10 | [checkpoint (20 MB)](https://storage.googleapis.com/tf_model_garden/vision/movinet/movinet_a2_base.tar.gz) | [tfhub](https://tfhub.dev/tensorflow/movinet/a2/base/kinetics-600/classification/) |
@@ -123,10 +132,19 @@ Base models implement standard 3D convolutions without stream buffers.
 #### Streaming Models

-Streaming models implement causal 3D convolutions with stream buffers.
+Streaming models implement causal (2+1)D convolutions with stream buffers.
+Streaming models use (2+1)D convolution instead of 3D to utilize optimized
+[`tf.nn.conv2d`](https://www.tensorflow.org/api_docs/python/tf/nn/conv2d)
+operations, which offer fast inference on CPU. Streaming models can be run on
+individual frames or on larger video clips like base models.

+Note: A3, A4, and A5 models use a positional encoding in the squeeze-excitation
+blocks, while A0, A1, and A2 do not. For the smaller models, accuracy is
+unaffected without positional encoding, while for the larger models accuracy is
+significantly worse without positional encoding.

-| Model Name | Top-1 Accuracy | Top-5 Accuracy | Input Shape\* | GFLOPs\*\* | Chekpoint | TF Hub SavedModel |
-|------------|----------------|----------------|---------------|------------|-----------|-------------------|
+| Model Name | Top-1 Accuracy | Top-5 Accuracy | Input Shape\* | GFLOPs\*\* | Checkpoint | TF Hub SavedModel |
+|------------|----------------|----------------|---------------|------------|------------|-------------------|
 | MoViNet-A0-Stream | 72.05 | 90.63 | 50 x 172 x 172 | 2.7 | [checkpoint (12 MB)](https://storage.googleapis.com/tf_model_garden/vision/movinet/movinet_a0_stream.tar.gz) | [tfhub](https://tfhub.dev/tensorflow/movinet/a0/stream/kinetics-600/classification/) |
 | MoViNet-A1-Stream | 76.45 | 93.25 | 50 x 172 x 172 | 6.0 | [checkpoint (18 MB)](https://storage.googleapis.com/tf_model_garden/vision/movinet/movinet_a1_stream.tar.gz) | [tfhub](https://tfhub.dev/tensorflow/movinet/a1/stream/kinetics-600/classification/) |
 | MoViNet-A2-Stream | 78.40 | 94.05 | 50 x 224 x 224 | 10 | [checkpoint (20 MB)](https://storage.googleapis.com/tf_model_garden/vision/movinet/movinet_a2_stream.tar.gz) | [tfhub](https://tfhub.dev/tensorflow/movinet/a2/stream/kinetics-600/classification/) |
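For intuition, a (2+1)D convolution factorizes a 3D kernel into a spatial convolution followed by a temporal convolution. A minimal illustrative sketch (not the repository's implementation, which reshapes to `tf.nn.conv2d` for speed and uses causal temporal padding):

```python
import tensorflow as tf

# Factorize a 3x3x3 convolution into 1x3x3 (spatial) + 3x1x1 (temporal).
inputs = tf.ones([1, 8, 172, 172, 3])  # [batch, time, height, width, channels]
spatial = tf.keras.layers.Conv3D(16, (1, 3, 3), padding='same')(inputs)
temporal = tf.keras.layers.Conv3D(16, (3, 1, 1), padding='same')(spatial)
print(temporal.shape)  # (1, 8, 172, 172, 16)
```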
@@ -139,6 +157,35 @@ duration of the 10-second clip.
 \*\*GFLOPs per video on Kinetics 600.

+Note: current streaming model checkpoints have been updated with a slightly
+different architecture. To download the old checkpoints, insert `_legacy` before
+`.tar.gz` in the URL. E.g., `movinet_a0_stream_legacy.tar.gz`.

+##### TF Lite Streaming Models

+For convenience, we provide converted TF Lite models for inference on mobile
+devices. See the [TF Lite Example](#tf-lite-example) to export and run your own
+models.

+For reference, MoViNet-A0-Stream runs with a similar latency to
+[MobileNetV3-Large](https://tfhub.dev/google/imagenet/mobilenet_v3_large_100_224/classification/)
+with +5% accuracy on Kinetics 600.

+| Model Name | Input Shape | Pixel 4 Latency\* | x86 Latency\* | TF Lite Binary |
+|------------|-------------|-------------------|---------------|----------------|
+| MoViNet-A0-Stream | 1 x 1 x 172 x 172 | 22 ms | 16 ms | [TF Lite (13 MB)](https://storage.googleapis.com/tf_model_garden/vision/movinet/movinet_a0_stream.tflite) |
+| MoViNet-A1-Stream | 1 x 1 x 172 x 172 | 42 ms | 33 ms | [TF Lite (45 MB)](https://storage.googleapis.com/tf_model_garden/vision/movinet/movinet_a1_stream.tflite) |
+| MoViNet-A2-Stream | 1 x 1 x 224 x 224 | 200 ms | 66 ms | [TF Lite (53 MB)](https://storage.googleapis.com/tf_model_garden/vision/movinet/movinet_a2_stream.tflite) |
+| MoViNet-A3-Stream | 1 x 1 x 256 x 256 | - | 120 ms | [TF Lite (73 MB)](https://storage.googleapis.com/tf_model_garden/vision/movinet/movinet_a3_stream.tflite) |
+| MoViNet-A4-Stream | 1 x 1 x 290 x 290 | - | 300 ms | [TF Lite (101 MB)](https://storage.googleapis.com/tf_model_garden/vision/movinet/movinet_a4_stream.tflite) |
+| MoViNet-A5-Stream | 1 x 1 x 320 x 320 | - | 450 ms | [TF Lite (153 MB)](https://storage.googleapis.com/tf_model_garden/vision/movinet/movinet_a5_stream.tflite) |

+\*Single-frame latency measured with unaltered float32 operations on a single
+CPU core. Observed latency may differ depending on hardware configuration.
+Measured on a stock Pixel 4 (Android 11) and an x86 Intel Xeon W-2135 CPU.
 ## Prediction Examples

 Please check out our [Colab Notebook](https://colab.research.google.com/github/tensorflow/models/tree/master/official/vision/beta/projects/movinet/movinet_tutorial.ipynb)
@@ -146,7 +193,7 @@ to get started with MoViNets.
 This section provides examples on how to run prediction.

-For base models, run the following:
+For **base models**, run the following:

 ```python
 import tensorflow as tf
@@ -181,7 +228,7 @@ output = model(inputs)
 prediction = tf.argmax(output, -1)
 ```

-For streaming models, run the following:
+For **streaming models**, run the following:

 ```python
 import tensorflow as tf
@@ -189,20 +236,31 @@ import tensorflow as tf
 from official.vision.beta.projects.movinet.modeling import movinet
 from official.vision.beta.projects.movinet.modeling import movinet_model

+model_id = 'a0'
+use_positional_encoding = model_id in {'a3', 'a4', 'a5'}

 # Create backbone and model.
 backbone = movinet.Movinet(
-    model_id='a0',
+    model_id=model_id,
     causal=True,
+    conv_type='2plus1d',
+    se_type='2plus3d',
+    activation='hard_swish',
+    gating_activation='hard_sigmoid',
+    use_positional_encoding=use_positional_encoding,
     use_external_states=True,
 )

 model = movinet_model.MovinetClassifier(
-    backbone, num_classes=600, output_states=True)
+    backbone,
+    num_classes=600,
+    output_states=True)

 # Create your example input here.
 # Refer to the paper for recommended input shapes.
 inputs = tf.ones([1, 8, 172, 172, 3])

-# [Optional] Build the model and load a pretrained checkpoint
+# [Optional] Build the model and load a pretrained checkpoint.
 model.build(inputs.shape)

 checkpoint_dir = '/path/to/checkpoint'
@@ -237,23 +295,89 @@ non_streaming_output, _ = model({**init_states, 'image': inputs})
 non_streaming_prediction = tf.argmax(non_streaming_output, -1)
 ```
+## TF Lite Example

+This section outlines an example of how to export a model to run on mobile
+devices with [TF Lite](https://www.tensorflow.org/lite).

+First, convert to [TF SavedModel](https://www.tensorflow.org/guide/saved_model)
+by running `export_saved_model.py`. For example, for `MoViNet-A0-Stream`, run:

+```shell
+python3 export_saved_model.py \
+  --model_id=a0 \
+  --causal=True \
+  --conv_type=2plus1d \
+  --se_type=2plus3d \
+  --activation=hard_swish \
+  --gating_activation=hard_sigmoid \
+  --use_positional_encoding=False \
+  --num_classes=600 \
+  --batch_size=1 \
+  --num_frames=1 \
+  --image_size=172 \
+  --bundle_input_init_states_fn=False \
+  --checkpoint_path=/path/to/checkpoint \
+  --export_path=/tmp/movinet_a0_stream
+```

+Then the SavedModel can be converted to TF Lite using the
+[`TFLiteConverter`](https://www.tensorflow.org/lite/convert):

+```python
+saved_model_dir = '/tmp/movinet_a0_stream'

+converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
+tflite_model = converter.convert()

+with open('/tmp/movinet_a0_stream.tflite', 'wb') as f:
+  f.write(tflite_model)
+```

+To run inference with the Python API, use
+[`tf.lite.Interpreter`](https://www.tensorflow.org/lite/guide/inference#load_and_run_a_model_in_python):

+```python
+# Create the interpreter and signature runner.
+interpreter = tf.lite.Interpreter('/tmp/movinet_a0_stream.tflite')
+signature = interpreter.get_signature_runner()

+# Extract state names and create the initial (zero) states.
+def state_name(name: str) -> str:
+  return name[len('serving_default_'):-len(':0')]

+init_states = {
+    state_name(x['name']): tf.zeros(x['shape'], dtype=x['dtype'])
+    for x in interpreter.get_input_details()
+}
+del init_states['image']

+# Insert your video clip here.
+video = tf.ones([1, 8, 172, 172, 3])
+clips = tf.split(video, video.shape[1], axis=1)

+# To run on a video, pass in one frame at a time.
+states = init_states
+for clip in clips:
+  # Input shape: [1, 1, 172, 172, 3]
+  outputs = signature(**states, image=clip)
+  logits = outputs.pop('logits')
+  states = outputs
+```

+Follow the [official guide](https://www.tensorflow.org/lite/guide) to run a
+model with TF Lite on your mobile device.
 ## Training and Evaluation

 Run this command line for continuous training and evaluation.

 ```shell
-MODE=train_and_eval  # Can also be 'train'
+MODE=train_and_eval  # Can also be 'train' if using a separate evaluator job
 CONFIG_FILE=official/vision/beta/projects/movinet/configs/yaml/movinet_a0_k600_8x8.yaml
 python3 official/vision/beta/projects/movinet/train.py \
     --experiment=movinet_kinetics600 \
     --mode=${MODE} \
-    --model_dir=/tmp/movinet/ \
-    --config_file=${CONFIG_FILE} \
-    --params_override="" \
-    --gin_file="" \
-    --gin_params="" \
-    --tpu="" \
-    --tf_data_service=""
+    --model_dir=/tmp/movinet_a0_base/ \
+    --config_file=${CONFIG_FILE}
 ```

 Run this command line for evaluation.
@@ -264,13 +388,8 @@ CONFIG_FILE=official/vision/beta/projects/movinet/configs/yaml/movinet_a0_k600_8
 python3 official/vision/beta/projects/movinet/train.py \
     --experiment=movinet_kinetics600 \
     --mode=${MODE} \
-    --model_dir=/tmp/movinet/ \
-    --config_file=${CONFIG_FILE} \
-    --params_override="" \
-    --gin_file="" \
-    --gin_params="" \
-    --tpu="" \
-    --tf_data_service=""
+    --model_dir=/tmp/movinet_a0_base/ \
+    --config_file=${CONFIG_FILE}
 ```

 ## License
...
@@ -130,6 +130,7 @@ class MovinetModel(video_classification.VideoClassificationModel):
       norm_momentum=0.99,
       norm_epsilon=1e-3,
       use_sync_bn=True)
+  activation: str = 'swish'
   output_states: bool = False
...
@@ -15,6 +15,11 @@ task:
     movinet:
       model_id: 'a0'
       causal: true
+      # Note: we train with '3d_2plus1d', but convert to '2plus1d' for inference
+      conv_type: '3d_2plus1d'
+      se_type: '2plus3d'
+      activation: 'hard_swish'
+      gating_activation: 'hard_sigmoid'
       stochastic_depth_drop_rate: 0.2
     norm_activation:
       use_sync_bn: true
...
@@ -15,6 +15,11 @@ task:
     movinet:
       model_id: 'a1'
       causal: true
+      # Note: we train with '3d_2plus1d', but convert to '2plus1d' for inference
+      conv_type: '3d_2plus1d'
+      se_type: '2plus3d'
+      activation: 'hard_swish'
+      gating_activation: 'hard_sigmoid'
       stochastic_depth_drop_rate: 0.2
     norm_activation:
       use_sync_bn: true
...
@@ -15,10 +15,15 @@ task:
     movinet:
       model_id: 'a2'
       causal: true
+      # Note: we train with '3d_2plus1d', but convert to '2plus1d' for inference
+      conv_type: '3d_2plus1d'
+      se_type: '2plus3d'
+      activation: 'hard_swish'
+      gating_activation: 'hard_sigmoid'
       stochastic_depth_drop_rate: 0.2
     norm_activation:
       use_sync_bn: true
-    dropout_rate: 0.2
+    dropout_rate: 0.5
   train_data:
     name: kinetics600
     variant_name: rgb
...
@@ -15,6 +15,11 @@ task:
     movinet:
       model_id: 'a3'
       causal: true
+      # Note: we train with '3d_2plus1d', but convert to '2plus1d' for inference
+      conv_type: '3d_2plus1d'
+      se_type: '2plus3d'
+      activation: 'hard_swish'
+      gating_activation: 'hard_sigmoid'
       use_positional_encoding: true
       stochastic_depth_drop_rate: 0.2
     norm_activation:
...
@@ -15,6 +15,11 @@ task:
     movinet:
       model_id: 'a4'
       causal: true
+      # Note: we train with '3d_2plus1d', but convert to '2plus1d' for inference
+      conv_type: '3d_2plus1d'
+      se_type: '2plus3d'
+      activation: 'hard_swish'
+      gating_activation: 'hard_sigmoid'
       use_positional_encoding: true
       stochastic_depth_drop_rate: 0.2
     norm_activation:
...
@@ -15,6 +15,11 @@ task:
     movinet:
       model_id: 'a5'
       causal: true
+      # Note: we train with '3d_2plus1d', but convert to '2plus1d' for inference
+      conv_type: '3d_2plus1d'
+      se_type: '2plus3d'
+      activation: 'hard_swish'
+      gating_activation: 'hard_sigmoid'
       use_positional_encoding: true
       stochastic_depth_drop_rate: 0.2
     norm_activation:
@@ -42,7 +47,8 @@ task:
   validation_data:
     name: kinetics600
     feature_shape: !!python/tuple
-    - 120
+    # Evaluate on 115 frames instead of 120, as the model will get OOM on TPU
+    - 115
     - 320
     - 320
     - 3
...
@@ -15,6 +15,11 @@ task:
     movinet:
       model_id: 't0'
       causal: true
+      # Note: we train with '3d_2plus1d', but convert to '2plus1d' for inference
+      conv_type: '3d_2plus1d'
+      se_type: '2plus3d'
+      activation: 'hard_swish'
+      gating_activation: 'hard_sigmoid'
       stochastic_depth_drop_rate: 0.2
     norm_activation:
       use_sync_bn: true
...
@@ -28,6 +28,26 @@ python3 export_saved_model.py \
     --checkpoint_path=""
 ```

+Export for TF Lite example (`--num_frames=1` uses a single frame for streaming
+mode; `--image_size=172` sets the input resolution for the model):

+```shell
+python3 export_saved_model.py \
+  --model_id=a0 \
+  --causal=True \
+  --conv_type=2plus1d \
+  --se_type=2plus3d \
+  --activation=hard_swish \
+  --gating_activation=hard_sigmoid \
+  --use_positional_encoding=False \
+  --num_classes=600 \
+  --batch_size=1 \
+  --num_frames=1 \
+  --image_size=172 \
+  --bundle_input_init_states_fn=False \
+  --checkpoint_path=/path/to/checkpoint \
+  --export_path=/tmp/movinet_a0_stream
+```

 To use an exported saved_model, refer to export_saved_model_test.py.
 """
@@ -79,6 +99,10 @@ flags.DEFINE_integer(
 flags.DEFINE_integer(
     'image_size', None,
     'The resolution of the input. Set to None for dynamic input.')
+flags.DEFINE_bool(
+    'bundle_input_init_states_fn', True,
+    'Add init_states as a function signature to the saved model. '
+    'This is not necessary if the input shape is static (e.g., for TF Lite).')
 flags.DEFINE_string(
     'checkpoint_path', '',
     'Checkpoint path to load. Leave blank for default initialization.')
@@ -97,24 +121,33 @@ def main(_) -> None:
   # Use dimensions of 1 except the channels to export faster,
   # since we only really need the last dimension to build and get the output
-  # states. These dimensions will be set to `None` once the model is built.
+  # states. These dimensions can be set to `None` once the model is built.
   input_shape = [1 if s is None else s for s in input_specs.shape]

+  activation = FLAGS.activation
+  if activation == 'swish':
+    # Override swish activation implementation to remove custom gradients
+    activation = 'simple_swish'

   backbone = movinet.Movinet(
-      FLAGS.model_id,
+      model_id=FLAGS.model_id,
       causal=FLAGS.causal,
+      use_positional_encoding=FLAGS.use_positional_encoding,
       conv_type=FLAGS.conv_type,
-      use_external_states=FLAGS.causal,
+      se_type=FLAGS.se_type,
       input_specs=input_specs,
-      activation=FLAGS.activation,
+      activation=activation,
       gating_activation=FLAGS.gating_activation,
-      se_type=FLAGS.se_type,
-      use_positional_encoding=FLAGS.use_positional_encoding)
+      use_sync_bn=False,
+      use_external_states=FLAGS.causal)
   model = movinet_model.MovinetClassifier(
       backbone,
       num_classes=FLAGS.num_classes,
       output_states=FLAGS.causal,
-      input_specs=dict(image=input_specs))
+      input_specs=dict(image=input_specs),
+      # TODO(dankondratyuk): currently set to swish, but will need to
+      # re-train to use other activations.
+      activation='simple_swish')
   model.build(input_shape)

   # Compile model to generate some internal Keras variables.
@@ -131,7 +164,7 @@ def main(_) -> None:
   # with the full output state shapes.
   input_image = tf.ones(input_shape)
   _, states = model({**model.init_states(input_shape), 'image': input_image})
-  _, states = model({**states, 'image': input_image})
+  _ = model({**states, 'image': input_image})

   # Create a function to explicitly set the names of the outputs
   def predict(inputs):
@@ -153,7 +186,10 @@ def main(_) -> None:
     init_states_fn = init_states_fn.get_concrete_function(
         tf.TensorSpec([5], dtype=tf.int32))
-  signatures = {'call': predict_fn, 'init_states': init_states_fn}
+  if FLAGS.bundle_input_init_states_fn:
+    signatures = {'call': predict_fn, 'init_states': init_states_fn}
+  else:
+    signatures = predict_fn

   tf.keras.models.save_model(
       model, FLAGS.export_path, signatures=signatures)
...
@@ -48,7 +48,7 @@ class ExportSavedModelTest(tf.test.TestCase):
     example_input = tf.ones([1, 8, 172, 172, 3])

     outputs = model(example_input)
-    self.assertEqual(outputs.shape, [1, 600])
+    self.assertAllEqual(outputs.shape, [1, 600])

   def test_movinet_export_a0_stream_with_tfhub(self):
     saved_model_path = self.get_temp_dir()
@@ -94,9 +94,55 @@ class ExportSavedModelTest(tf.test.TestCase):
     for frame in frames:
       outputs, states = model({**states, 'image': frame})
-    self.assertEqual(outputs.shape, [1, 600])
+    self.assertAllEqual(outputs.shape, [1, 600])
     self.assertNotEmpty(states)
     self.assertAllClose(outputs, expected_outputs, 1e-5, 1e-5)

+  def test_movinet_export_a0_stream_with_tflite(self):
+    saved_model_path = self.get_temp_dir()
+
+    FLAGS.export_path = saved_model_path
+    FLAGS.model_id = 'a0'
+    FLAGS.causal = True
+    FLAGS.conv_type = '2plus1d'
+    FLAGS.se_type = '2plus3d'
+    FLAGS.activation = 'hard_swish'
+    FLAGS.gating_activation = 'hard_sigmoid'
+    FLAGS.use_positional_encoding = False
+    FLAGS.num_classes = 600
+    FLAGS.batch_size = 1
+    FLAGS.num_frames = 1
+    FLAGS.image_size = 172
+    FLAGS.bundle_input_init_states_fn = False
+    export_saved_model.main('unused_args')
+
+    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_path)
+    tflite_model = converter.convert()
+
+    interpreter = tf.lite.Interpreter(model_content=tflite_model)
+    signature = interpreter.get_signature_runner()
+
+    def state_name(name: str) -> str:
+      return name[len('serving_default_'):-len(':0')]
+
+    init_states = {
+        state_name(x['name']): tf.zeros(x['shape'], dtype=x['dtype'])
+        for x in interpreter.get_input_details()
+    }
+    del init_states['image']
+
+    video = tf.ones([1, 8, 172, 172, 3])
+    clips = tf.split(video, video.shape[1], axis=1)
+
+    states = init_states
+    for clip in clips:
+      outputs = signature(**states, image=clip)
+      logits = outputs.pop('logits')
+      states = outputs
+
+    self.assertAllEqual(logits.shape, [1, 600])
+    self.assertNotEmpty(states)

 if __name__ == '__main__':
   tf.test.main()
@@ -43,6 +43,9 @@ S12: KernelSize = (1, 2, 2)
 S22: KernelSize = (2, 2, 2)
 S21: KernelSize = (2, 1, 1)

+# Type for a state container (map)
+TensorMap = Mapping[str, tf.Tensor]

 @dataclasses.dataclass
 class BlockSpec:
@@ -319,6 +322,7 @@ class Movinet(tf.keras.Model):
                bias_regularizer: Optional[str] = None,
                stochastic_depth_drop_rate: float = 0.,
                use_external_states: bool = False,
+               output_states: bool = True,
                **kwargs):
     """MoViNet initialization function.
@@ -353,6 +357,10 @@ class Movinet(tf.keras.Model):
       stochastic_depth_drop_rate: the base rate for stochastic depth.
       use_external_states: if True, expects states to be passed as additional
         input.
+      output_states: if True, output intermediate states that can be used to run
+        the model in streaming mode. Inputting the output states of the
+        previous input clip with the current input clip will utilize a stream
+        buffer for streaming video.
       **kwargs: keyword arguments to be passed.
     """
     block_specs = BLOCK_SPECS[model_id]
@@ -385,6 +393,7 @@ class Movinet(tf.keras.Model):
     self._bias_regularizer = bias_regularizer
     self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
     self._use_external_states = use_external_states
+    self._output_states = output_states

     if self._use_external_states and not self._causal:
       raise ValueError('External states should be used with causal mode.')
@@ -411,8 +420,7 @@ class Movinet(tf.keras.Model):
       self,
       input_specs: tf.keras.layers.InputSpec,
       state_specs: Optional[Mapping[str, tf.keras.layers.InputSpec]] = None,
-  ) -> Tuple[Mapping[str, tf.keras.Input], Tuple[Mapping[str, tf.Tensor],
-             Mapping[str, tf.Tensor]]]:
+  ) -> Tuple[TensorMap, Union[TensorMap, Tuple[TensorMap, TensorMap]]]:
     """Builds the model network.

     Args:
@@ -423,7 +431,7 @@ class Movinet(tf.keras.Model):
     Returns:
       Inputs and outputs as a tuple. Inputs are expected to be a dict with
       base input and states. Outputs are expected to be a dict of endpoints
-      and output states.
+      and (optional) output states.
     """
     state_specs = state_specs if state_specs is not None else {}
@@ -519,7 +527,7 @@ class Movinet(tf.keras.Model):
     else:
       raise ValueError('Unknown block type {}'.format(block))

-    outputs = (endpoints, states)
+    outputs = (endpoints, states) if self._output_states else endpoints

     return inputs, outputs
@@ -679,6 +687,8 @@ class Movinet(tf.keras.Model):
         'kernel_regularizer': self._kernel_regularizer,
         'bias_regularizer': self._bias_regularizer,
         'stochastic_depth_drop_rate': self._stochastic_depth_drop_rate,
+        'use_external_states': self._use_external_states,
+        'output_states': self._output_states,
     }
     return config_dict
...
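With `use_external_states` and `output_states` now included in `get_config()`, both options round-trip through backbone serialization. A minimal sketch (hypothetical check, using the constructor arguments shown in the README examples above):

```python
from official.vision.beta.projects.movinet.modeling import movinet

# output_states is now serialized alongside the other backbone options.
backbone = movinet.Movinet(model_id='a0', causal=True, output_states=True)
config = backbone.get_config()
assert config['output_states'] is True
```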
@@ -265,7 +265,7 @@ class ConvBlock(tf.keras.layers.Layer):
           tf.keras.regularizers.L2(KERNEL_WEIGHT_DECAY),
       use_batch_norm: bool = True,
       batch_norm_layer: tf.keras.layers.Layer =
-          tf.keras.layers.experimental.SyncBatchNormalization,
+          tf.keras.layers.BatchNormalization,
       batch_norm_momentum: float = 0.99,
       batch_norm_epsilon: float = 1e-3,
       activation: Optional[Any] = None,
@@ -547,8 +547,8 @@ class StreamConvBlock(ConvBlock):
       kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = tf.keras
           .regularizers.L2(KERNEL_WEIGHT_DECAY),
       use_batch_norm: bool = True,
-      batch_norm_layer: tf.keras.layers.Layer = tf.keras.layers.experimental
-          .SyncBatchNormalization,
+      batch_norm_layer: tf.keras.layers.Layer =
+          tf.keras.layers.BatchNormalization,
       batch_norm_momentum: float = 0.99,
       batch_norm_epsilon: float = 1e-3,
       activation: Optional[Any] = None,
@@ -915,7 +915,7 @@ class SkipBlock(tf.keras.layers.Layer):
       kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] =
           tf.keras.regularizers.L2(KERNEL_WEIGHT_DECAY),
       batch_norm_layer: tf.keras.layers.Layer =
-          tf.keras.layers.experimental.SyncBatchNormalization,
+          tf.keras.layers.BatchNormalization,
       batch_norm_momentum: float = 0.99,
       batch_norm_epsilon: float = 1e-3,
       **kwargs):
@@ -1031,8 +1031,8 @@ class MovinetBlock(tf.keras.layers.Layer):
       kernel_initializer: tf.keras.initializers.Initializer = 'HeNormal',
       kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = tf.keras
           .regularizers.L2(KERNEL_WEIGHT_DECAY),
-      batch_norm_layer: tf.keras.layers.Layer = tf.keras.layers.experimental
-          .SyncBatchNormalization,
+      batch_norm_layer: tf.keras.layers.Layer =
+          tf.keras.layers.BatchNormalization,
       batch_norm_momentum: float = 0.99,
       batch_norm_epsilon: float = 1e-3,
       state_prefix: Optional[str] = None,
@@ -1232,8 +1232,8 @@ class Stem(tf.keras.layers.Layer):
       kernel_initializer: tf.keras.initializers.Initializer = 'HeNormal',
       kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = tf.keras
           .regularizers.L2(KERNEL_WEIGHT_DECAY),
-      batch_norm_layer: tf.keras.layers.Layer = tf.keras.layers.experimental
-          .SyncBatchNormalization,
+      batch_norm_layer: tf.keras.layers.Layer =
+          tf.keras.layers.BatchNormalization,
       batch_norm_momentum: float = 0.99,
       batch_norm_epsilon: float = 1e-3,
       state_prefix: Optional[str] = None,
@@ -1340,8 +1340,8 @@ class Head(tf.keras.layers.Layer):
       kernel_initializer: tf.keras.initializers.Initializer = 'HeNormal',
       kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = tf.keras
           .regularizers.L2(KERNEL_WEIGHT_DECAY),
-      batch_norm_layer: tf.keras.layers.Layer = tf.keras.layers.experimental
-          .SyncBatchNormalization,
+      batch_norm_layer: tf.keras.layers.Layer =
+          tf.keras.layers.BatchNormalization,
       batch_norm_momentum: float = 0.99,
       batch_norm_epsilon: float = 1e-3,
       state_prefix: Optional[str] = None,
@@ -1470,6 +1470,7 @@ class ClassifierHead(tf.keras.layers.Layer):
     self._num_classes = num_classes
     self._dropout_rate = dropout_rate
     self._conv_type = conv_type
+    self._activation = activation
     self._output_activation = output_activation
     self._max_pool_predictions = max_pool_predictions
     self._kernel_initializer = kernel_initializer
@@ -1509,6 +1510,7 @@ class ClassifierHead(tf.keras.layers.Layer):
         'num_classes': self._num_classes,
         'dropout_rate': self._dropout_rate,
         'conv_type': self._conv_type,
+        'activation': self._activation,
         'output_activation': self._output_activation,
         'max_pool_predictions': self._max_pool_predictions,
         'kernel_initializer': self._kernel_initializer,
...
@@ -36,6 +36,7 @@ class MovinetClassifier(tf.keras.Model):
                backbone: tf.keras.Model,
                num_classes: int,
                input_specs: Optional[Mapping[str, tf.keras.layers.InputSpec]] = None,
+               activation: str = 'swish',
                dropout_rate: float = 0.0,
                kernel_initializer: str = 'HeNormal',
                kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
@@ -48,6 +49,7 @@ class MovinetClassifier(tf.keras.Model):
       backbone: A 3d backbone network.
       num_classes: Number of classes in classification task.
       input_specs: Specs of the input tensor.
+      activation: name of the main activation function.
       dropout_rate: Rate for dropout regularization.
       kernel_initializer: Kernel initializer for the final dense layer.
       kernel_regularizer: Kernel regularizer.
@@ -65,6 +67,7 @@ class MovinetClassifier(tf.keras.Model):
     self._num_classes = num_classes
     self._input_specs = input_specs
+    self._activation = activation
     self._dropout_rate = dropout_rate
     self._kernel_initializer = kernel_initializer
     self._kernel_regularizer = kernel_regularizer
@@ -151,7 +154,8 @@ class MovinetClassifier(tf.keras.Model):
         dropout_rate=self._dropout_rate,
         kernel_initializer=self._kernel_initializer,
         kernel_regularizer=self._kernel_regularizer,
-        conv_type=backbone.conv_type)(
+        conv_type=backbone.conv_type,
+        activation=self._activation)(
             x)

     outputs = (x, states) if self._output_states else x
@@ -180,6 +184,7 @@ class MovinetClassifier(tf.keras.Model):
   def get_config(self):
     config = {
         'backbone': self._backbone,
+        'activation': self._activation,
         'num_classes': self._num_classes,
         'input_specs': self._input_specs,
         'dropout_rate': self._dropout_rate,
@@ -226,6 +231,7 @@ def build_movinet_model(
       num_classes=num_classes,
       kernel_regularizer=l2_regularizer,
       input_specs=input_specs_dict,
+      activation=model_config.activation,
       dropout_rate=model_config.dropout_rate,
       output_states=model_config.output_states)
...