Commit 063cef34 authored by SunJong Park's avatar SunJong Park Committed by A. Unique TensorFlower

Copybara import of the project:

--
f544441ff132c14c0d95026181eb12c69b4eac1d by SunJong Park <53969182+ryan0507@users.noreply.github.com>:

Assemblenet++ Migration with TF2 (#10288)

* Assemblenet++ implementation with TF2 (UCF101 Dataset)

* pylint.sh passed

* train.py document updated

* README.md updated - AssembleNet and AssembleNet++

* YAML and configuration updated - AssembleNet and AssembleNet++

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/models/pull/10355 from tensorflow:assemblenet f544441ff132c14c0d95026181eb12c69b4eac1d
PiperOrigin-RevId: 414047541
parent 024ebd81
@@ -2,6 +2,8 @@
This repository contains the official implementations of the following papers.
The original implementations can be found [here](https://github.com/google-research/google-research/tree/master/assemblenet).

[![Paper](http://img.shields.io/badge/Paper-arXiv.2008.03800-B3181B?logo=arXiv)](https://arxiv.org/abs/1905.13209)
[AssembleNet: Searching for Multi-Stream Neural Connectivity in Video
Architectures](https://arxiv.org/abs/1905.13209)
@@ -10,5 +12,130 @@ Architectures](https://arxiv.org/abs/1905.13209)
[AssembleNet++: Assembling Modality Representations via Attention
Connections](https://arxiv.org/abs/2008.08072)

DISCLAIMER: The AssembleNet++ implementation is still under development. No
support will be provided during the development phase.
## Description
### AssembleNet vs. AssembleNet++
AssembleNet and AssembleNet++ both focus on neural connectivity search for
multi-stream video CNN architectures. They learn weights for the connections
between multiple convolutional blocks (composed of (2+1)D or 3D residual
modules) organized sequentially or in parallel, thereby optimizing the neural
architecture for the data/task.
AssembleNet++ adds *peer-attention* to the basic AssembleNet, which allows each
conv. block connection to be conditioned differently based on another block. It
is a form of channel-wise attention, which we found to be beneficial.
<img width="1158" alt="peer_attention" src="https://user-images.githubusercontent.com/53969182/135665233-e64ccda1-7dd3-45f2-9d77-5c4515703f13.png">
The code is provided in [assemblenet.py](modeling/assemblenet.py) and
[assemblenet_plus.py](modeling/assemblenet_plus.py). Notice that the provided
code uses (2+1)D residual modules as the building blocks of AssembleNet/++, but
you can use your own module while still benefitting from the connectivity search
of AssembleNet/++.
### Neural Architecture Search
As you will find from the [AssembleNet](https://arxiv.org/abs/1905.13209) paper,
the models we provide in [config files](configs/assemblenet.py) are the
result of architecture search/learning.
The architecture search in AssembleNet (and AssembleNet++) has two components:
(i) convolutional block configuration search using an evolutionary algorithm,
and (ii) one-shot differentiable connection search. We did not include the code
for the first part (i.e., evolution), as it relies on separate infrastructure
and requires substantially more computation. The second part (i.e.,
differentiable search) is included, however, and allows you to use this code to
search for the best connectivity for your own models.
That is, as also described in the
[AssembleNet++](https://arxiv.org/abs/2008.08072) paper, once the convolutional
blocks are decided (by the search or manually), you can use the provided code
to obtain the best block connections and learn attention connectivity in a
one-shot differentiable way. You just need to train the network (with
`FLAGS.model_edge_weights` set to `[]`), and the connectivity search will be
done simultaneously.
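As a hedged sketch of what this looks like at the model-building level (the
values below are illustrative and mirror the unit tests included in this
commit): passing an empty `model_edge_weights` makes the connection/attention
weights trainable variables, so they are learned together with the rest of the
network.

```python
# Illustrative sketch only; shapes/values mirror the unit tests in this commit.
# With model_edge_weights=[] the edge/attention weights are created as
# trainable variables instead of being loaded from a previous search.
import tensorflow as tf
from official.vision.beta.projects.assemblenet.configs import assemblenet as asn_config
from official.vision.beta.projects.assemblenet.modeling import assemblenet_plus as asnp

model = asnp.assemblenet_plus(
    assemblenet_depth=50,
    num_classes=101,
    num_frames=32,
    model_structure=asn_config.asn50_structure,
    model_edge_weights=[],  # empty -> connectivity is searched while training
    input_specs=tf.keras.layers.InputSpec(shape=(2, 32, 64, 64, 3)),
    attention_mode='peer',
)
```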
### AssembleNet and AssembleNet++ Structure Format
The format we use to specify AssembleNet/++ architectures is as follows: It is a
`list` corresponding to a graph representation of the network, where a node is a
convolutional block and an edge specifies a connection from one block to
another. Each node itself (in the structure list) is a `list` with the following
format: `[block_level, [list_of_input_blocks], number_filter, temporal_dilation,
spatial_stride]`. `[list_of_input_blocks]` should be the list of node indexes
whose values are less than the index of the node itself. The 'stems' of the
network, which directly take raw inputs, follow a different node format:
`[stem_type, temporal_dilation]`. The `stem_type` is -1 for the RGB stem and -2
for the optical flow stem; -3 is reserved for the object segmentation input.
In AssembleNet++lite, instead of passing a single `int` for `number_filter`, we
pass a list/tuple of three `int`s. They specify the number of channels to be
used for each layer in the inverted bottleneck modules.
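For illustration, a toy (made-up) structure in this format might look like the
sketch below; the actual searched structures (e.g., `asn50_structure` and its
weights) are defined in [configs/assemblenet.py](configs/assemblenet.py).

```python
# Toy AssembleNet-style structure list, for illustration only (not a searched
# architecture). Stems use [stem_type, temporal_dilation]; other nodes use
# [block_level, [list_of_input_blocks], number_filter, temporal_dilation,
#  spatial_stride].
toy_structure = [
    [-1, 4],                 # node 0: RGB stem, temporal dilation 4
    [-2, 1],                 # node 1: optical-flow stem, temporal dilation 1
    [1, [0, 1], 32, 1, 2],   # node 2: level-1 block fed by both stems
    [2, [2], 64, 2, 2],      # node 3: level-2 block fed by node 2
    [3, [2, 3], 128, 1, 2],  # node 4: level-3 block fed by nodes 2 and 3
]

# In AssembleNet++lite, number_filter becomes a list/tuple of three ints, one
# per layer of the inverted bottleneck module, e.g. (values are made up):
toy_lite_node = [2, [2], (27, 27, 81), 1, 2]
```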
### Optical Flow and Data Loading
Instead of loading optical flow as a separate input from the data pipeline, we
apply
[Representation Flow](https://github.com/piergiaj/representation-flow-cvpr19) to
the RGB frames so that the flow can be computed on the fly on TPU/GPU. It is
essentially optical flow, since it is computed directly from RGB frames. The
benefit is that no external optical flow extraction or data loading is needed:
you only need to feed RGB, and the flow is computed internally.
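As a hedged sketch of how this looks internally (the argument values mirror the
AssembleNet++ backbone code added in this commit; treat it as illustrative
rather than a supported standalone API):

```python
# Illustrative only: RGB frames reshaped to [batch*time, H, W, C] are passed
# through the RepresentationFlow layer, so no precomputed optical flow is ever
# loaded from the data pipeline.
import tensorflow as tf
from official.vision.beta.projects.assemblenet.modeling import rep_flow_2d_layer as rf

num_frames = 32
rgb = tf.random.uniform([2 * num_frames, 224, 224, 3])  # [batch*time, H, W, C]
flow = rf.RepresentationFlow(
    num_frames, depth=rgb.shape[-1], num_iter=40, bottleneck=1)(rgb)
```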
## History
2021/10/02: AssembleNet and AssembleNet++ implementations with the UCF101
dataset provided.
## Authors
* SunJong Park ([@GitHub ryan0507](https://github.com/ryan0507))
* HyeYoon Lee ([@GitHub hylee817](https://github.com/hylee817))
## Table of Contents
* [AssembleNet vs AssembleNet++](#assemblenet-vs-assemblenet)
* [Neural Architecture Search](#neural-architecture-search)
* [AssembleNet and AssembleNet++ Structure Format](#assemblenet-and-assemblenet-structure-format)
* [Optical Flow and Data Loading](#optical-flow-and-data-loading)
## Requirements
[![TensorFlow 2.5](https://img.shields.io/badge/TensorFlow-2.5.0-FF6F00?logo=tensorflow)](https://github.com/tensorflow/tensorflow/releases/tag/v2.5.0)
[![Python 3.8](https://img.shields.io/badge/Python-3.8-3776AB)](https://www.python.org/downloads/release/python-380/)
## Training and Evaluation
Example of training AssembleNet with the UCF101 TFDS dataset:
```bash
python -m official.vision.beta.projects.assemblenet.train \
  --mode=train_and_eval --experiment=assemblenet_ucf101 \
  --model_dir='YOUR_GS_BUCKET_TO_SAVE_MODEL' \
  --config_file=./official/vision/beta/projects/assemblenet/ucf101_assemblenet_tpu.yaml \
  --tpu=TPU_NAME
```
Example of training AssembleNet++ with the UCF101 TFDS dataset:
```bash
python -m official.vision.beta.projects.assemblenet.train \
  --mode=train_and_eval --experiment=assemblenetplus_ucf101 \
  --model_dir='YOUR_GS_BUCKET_TO_SAVE_MODEL' \
  --config_file=./official/vision/beta/projects/assemblenet/ucf101_assemblenet_plus_tpu.yaml \
  --tpu=TPU_NAME
```
Currently, we provide experiments with the kinetics400, kinetics500,
kinetics600, and UCF101 datasets. If you want to add a new experiment, register
a new configuration through `exp_factory` (see the sketch below).
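A hedged sketch of such a registration, modeled on the `assemblenet_ucf101`
factory added in this commit; `my_dataset` is a placeholder name and the UCF101
base config is only a stand-in for your own dataset config.

```python
# Hypothetical sketch of registering a new AssembleNet experiment; modeled on
# assemblenet_ucf101() in configs/assemblenet.py. 'my_dataset' is a placeholder.
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.vision.beta.configs import video_classification
from official.vision.beta.projects.assemblenet.configs import assemblenet as asn_config


@exp_factory.register_config_factory('assemblenet_my_dataset')
def assemblenet_my_dataset() -> cfg.ExperimentConfig:
  """Video classification on a custom dataset with AssembleNet."""
  exp = video_classification.video_classification_ucf101()  # swap in your base
  exp.task.train_data.dtype = 'bfloat16'
  exp.task.validation_data.dtype = 'bfloat16'
  feature_shape = (32, 224, 224, 3)
  model = asn_config.AssembleNetModel()
  model.backbone.assemblenet.blocks = asn_config.flat_lists_to_blocks(
      asn_config.asn50_structure, asn_config.asn_structure_weights)
  model.backbone.assemblenet.num_frames = feature_shape[0]
  exp.task.model = model
  return exp
```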
@@ -34,8 +34,9 @@ used for each layer in the inverted bottleneck modules.
The structure_weights specify the learned connection weights.
"""
import dataclasses
from typing import List, Tuple, Optional

from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import hyperparams
@@ -176,26 +177,47 @@ class AssembleNet(hyperparams.Config):
  blocks: Tuple[BlockSpec, ...] = tuple()
@dataclasses.dataclass
class AssembleNetPlus(hyperparams.Config):
model_id: str = '50'
num_frames: int = 0
attention_mode: str = 'None'
blocks: Tuple[BlockSpec, ...] = tuple()
use_object_input: bool = False
@dataclasses.dataclass
class Backbone3D(backbones_3d.Backbone3D):
  """Configuration for backbones.

  Attributes:
    type: 'str', type of backbone to be used, one of the fields below.
    assemblenet: AssembleNet backbone config.
    assemblenet_plus: AssembleNetPlus backbone config.
  """
  type: Optional[str] = None
  assemblenet: AssembleNet = AssembleNet()
  assemblenet_plus: AssembleNetPlus = AssembleNetPlus()


@dataclasses.dataclass
class AssembleNetModel(video_classification.VideoClassificationModel):
  """The AssembleNet model config."""
  model_type: str = 'assemblenet'
  backbone: Backbone3D = Backbone3D(type='assemblenet')
  norm_activation: common.NormActivation = common.NormActivation(
      norm_momentum=0.99, norm_epsilon=1e-5, use_sync_bn=True)
  max_pool_predictions: bool = False
@dataclasses.dataclass
class AssembleNetPlusModel(video_classification.VideoClassificationModel):
"""The AssembleNet model config."""
model_type: str = 'assemblenet_plus'
backbone: Backbone3D = Backbone3D(type='assemblenet_plus')
norm_activation: common.NormActivation = common.NormActivation(
norm_momentum=0.99, norm_epsilon=1e-5, use_sync_bn=True)
max_pool_predictions: bool = False
@exp_factory.register_config_factory('assemblenet50_kinetics600')
@@ -223,3 +245,43 @@ def assemblenet_kinetics600() -> cfg.ExperimentConfig:
      f'{exp.task.model.backbone.assemblenet}')
  return exp
@exp_factory.register_config_factory('assemblenet_ucf101')
def assemblenet_ucf101() -> cfg.ExperimentConfig:
"""Video classification on Videonet with assemblenet."""
exp = video_classification.video_classification_ucf101()
exp.task.train_data.dtype = 'bfloat16'
exp.task.validation_data.dtype = 'bfloat16'
feature_shape = (32, 224, 224, 3)
model = AssembleNetModel()
model.backbone.assemblenet.blocks = flat_lists_to_blocks(
asn50_structure, asn_structure_weights)
model.backbone.assemblenet.num_frames = feature_shape[0]
exp.task.model = model
assert exp.task.model.backbone.assemblenet.num_frames > 0, (
f'backbone num_frames '
f'{exp.task.model.backbone.assemblenet}')
return exp
@exp_factory.register_config_factory('assemblenetplus_ucf101')
def assemblenetplus_ucf101() -> cfg.ExperimentConfig:
"""Video classification on Videonet with assemblenet."""
exp = video_classification.video_classification_ucf101()
exp.task.train_data.dtype = 'bfloat16'
exp.task.validation_data.dtype = 'bfloat16'
feature_shape = (32, 224, 224, 3)
model = AssembleNetPlusModel()
model.backbone.assemblenet_plus.blocks = flat_lists_to_blocks(
asn50_structure, asn_structure_weights)
model.backbone.assemblenet_plus.num_frames = feature_shape[0]
exp.task.model = model
assert exp.task.model.backbone.assemblenet_plus.num_frames > 0, (
f'backbone num_frames '
f'{exp.task.model.backbone.assemblenet_plus}')
return exp
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
from absl.testing import parameterized
import tensorflow as tf
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.vision.beta.configs import video_classification as exp_cfg
from official.vision.beta.projects.assemblenet.configs import assemblenet
class AssemblenetTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters(
('assemblenet50_kinetics600',),)
def test_assemblenet_configs(self, config_name):
config = exp_factory.get_exp_config(config_name)
self.assertIsInstance(config, cfg.ExperimentConfig)
self.assertIsInstance(config.task, exp_cfg.VideoClassificationTask)
self.assertIsInstance(config.task.model, assemblenet.AssembleNetModel)
self.assertIsInstance(config.task.train_data, exp_cfg.DataConfig)
config.task.train_data.is_training = None
with self.assertRaises(KeyError):
config.validate()
def test_configs_conversion(self):
blocks = assemblenet.flat_lists_to_blocks(assemblenet.asn50_structure,
assemblenet.asn_structure_weights)
re_structure, re_weights = assemblenet.blocks_to_flat_lists(blocks)
self.assertAllEqual(
re_structure, assemblenet.asn50_structure, msg='asn50_structure')
self.assertAllEqual(
re_weights,
assemblenet.asn_structure_weights,
msg='asn_structure_weights')
if __name__ == '__main__':
tf.test.main()
# AssembleNet++ structure video classification on UCF-101 dataset
# --experiment_type=assemblenetplus_ucf101
# device : TPU v3-8
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'bfloat16'
task:
model:
backbone:
assemblenet_plus:
model_id: 50
num_frames: 32
attention_mode: 'peer'
use_object_input: false
type: 'assemblenet_plus'
dropout_rate: 0.5
norm_activation:
activation: relu
norm_momentum: 0.99
norm_epsilon: 0.00001
use_sync_bn: true
max_pool_predictions: true
train_data:
is_training: true
global_batch_size: 64
dtype: 'bfloat16'
tfds_data_dir: 'gs://oss-yonsei/tensorflow_datasets/'
validation_data:
is_training: false
global_batch_size: 64
dtype: 'bfloat16'
tfds_data_dir: 'gs://oss-yonsei/tensorflow_datasets/'
drop_remainder: true
trainer:
train_steps: 900000 # 500 epochs
validation_steps: 144
validation_interval: 144
steps_per_loop: 144 # NUM_EXAMPLES (9537) // global_batch_size
summary_interval: 144
checkpoint_interval: 144
optimizer_config:
optimizer:
type: 'sgd'
sgd:
momentum: 0.9
learning_rate:
type: 'exponential'
exponential:
initial_learning_rate: 0.008 # 0.008 * batch_size / 128
decay_steps: 532 # 2.5 * steps_per_epoch
decay_rate: 0.96
staircase: true
warmup:
type: 'linear'
linear:
warmup_steps: 50
# AssembleNet structure video classification on UCF-101 dataset
# --experiment_type=assemblenet_ucf101
# device : TPU v3-8
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'bfloat16'
task:
model:
backbone:
assemblenet:
model_id: 101
num_frames: 32
combine_method: 'sigmoid'
type: 'assemblenet'
dropout_rate: 0.5
norm_activation:
activation: relu
norm_momentum: 0.99
norm_epsilon: 0.00001
use_sync_bn: true
max_pool_predictions: true
train_data:
is_training: true
global_batch_size: 32
dtype: 'bfloat16'
tfds_data_dir: 'gs://oss-yonsei/tensorflow_datasets/'
validation_data:
is_training: false
global_batch_size: 32
dtype: 'bfloat16'
tfds_data_dir: 'gs://oss-yonsei/tensorflow_datasets/'
drop_remainder: true
trainer:
train_steps: 90000 # 500 epochs
validation_steps: 288
validation_interval: 288
steps_per_loop: 288 # NUM_EXAMPLES (9537) // global_batch_size
summary_interval: 288
checkpoint_interval: 288
optimizer_config:
optimizer:
type: 'sgd'
sgd:
momentum: 0.9
learning_rate:
type: 'exponential'
exponential:
initial_learning_rate: 0.008 # 0.008 * batch_size / 128
decay_steps: 1024 # 2.5 * steps_per_epoch
decay_rate: 0.96
staircase: true
warmup:
type: 'linear'
linear:
warmup_steps: 50
@@ -686,7 +686,7 @@ def multi_stream_heads(streams,
                       final_nodes,
                       num_frames,
                       num_classes,
                       max_pool_predictions: bool = False):
  """Layers for the classification heads.

  Args:
@@ -694,7 +694,7 @@ def multi_stream_heads(streams,
    final_nodes: A list of `int` where classification heads will be added.
    num_frames: `int` number of frames in the input tensor.
    num_classes: `int` number of possible classes for video classification.
    max_pool_predictions: Use max-pooling on predictions instead of mean
      pooling on features. It helps if you have more than 32 frames.

  Returns:
@@ -709,7 +709,7 @@ def multi_stream_heads(streams,
    net = tf.identity(net, 'final_avg_pool0')
    net = tf.reshape(net, [-1, num_frames, num_channels])
    if not max_pool_predictions:
      net = tf.reduce_mean(net, 1)
    return net
@@ -730,7 +730,7 @@ def multi_stream_heads(streams,
        kernel_initializer=tf.random_normal_initializer(stddev=.01))(
            inputs=outputs)
    outputs = tf.identity(outputs, 'final_dense0')
    if max_pool_predictions:
      pre_logits = outputs / np.sqrt(num_frames)
      acts = tf.nn.softmax(pre_logits, axis=1)
      outputs = tf.math.multiply(outputs, acts)
@@ -894,7 +894,7 @@ class AssembleNetModel(tf.keras.Model):
               model_structure: List[Any],
               input_specs: Optional[Mapping[str,
                                             tf.keras.layers.InputSpec]] = None,
               max_pool_predictions: bool = False,
               **kwargs):
    if not input_specs:
      input_specs = {
@@ -924,7 +924,7 @@ class AssembleNetModel(tf.keras.Model):
        grouping[3],
        num_frames,
        num_classes,
        max_pool_predictions=max_pool_predictions)

    super(AssembleNetModel, self).__init__(
        inputs=inputs, outputs=outputs, **kwargs)
@@ -981,7 +981,7 @@ def assemblenet_v1(assemblenet_depth: int,
                   input_specs: layers.InputSpec = layers.InputSpec(
                       shape=[None, None, None, None, 3]),
                   model_edge_weights: Optional[List[Any]] = None,
                   max_pool_predictions: bool = False,
                   combine_method: str = 'sigmoid',
                   **kwargs):
  """Returns the AssembleNet model for a given size and number of output classes."""
@@ -1009,7 +1009,7 @@ def assemblenet_v1(assemblenet_depth: int,
      num_frames=num_frames,
      model_structure=model_structure,
      input_specs=input_specs_dict,
      max_pool_predictions=max_pool_predictions,
      **kwargs)
@@ -1025,7 +1025,7 @@ def build_assemblenet_v1(
  backbone_type = backbone_config.type
  backbone_cfg = backbone_config.get()

  assert 'assemblenet' in backbone_type
  assemblenet_depth = int(backbone_cfg.model_id)
  if assemblenet_depth not in ASSEMBLENET_SPECS:
@@ -1072,5 +1072,5 @@ def build_assemblenet_model(
      num_frames=backbone_cfg.num_frames,
      model_structure=model_structure,
      input_specs=input_specs_dict,
      max_pool_predictions=model_config.max_pool_predictions)
  return model
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# coding=utf-8
# Copyright 2021 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains definitions for the AssembleNet++ [2] models (without object input).
Requires the AssembleNet++ architecture to be specified in
FLAGS.model_structure (and optionally FLAGS.model_edge_weights). This is
identical to the form described in assemblenet.py for the AssembleNet. Please
check assemblenet.py for the detailed format of the model strings.
AssembleNet++ adds `peer-attention' to the basic AssembleNet, which allows each
conv. block connection to be conditioned differently based on another block [2].
It is a form of channel-wise attention. Note that we learn to apply attention
independently for each frame.
The `peer-attention' implementation in this file is the version that enables
one-shot differentiable search of attention connectivity (Fig. 2 in [2]), using
a softmax weighted summation of possible attention vectors.
[2] Michael S. Ryoo, AJ Piergiovanni, Juhana Kangaspunta, Anelia Angelova,
AssembleNet++: Assembling Modality Representations via Attention
Connections. ECCV 2020
https://arxiv.org/abs/2008.08072
In order to take advantage of object inputs, one will need to set the flag
FLAGS.use_object_input as True, and provide the list of input tensors as an
input to the network, as shown in run_asn_with_object.py. This will require a
pre-processed object data stream.
It uses (2+1)D convolutions for video representations. The main AssembleNet++
takes a 4-D (N*T)HWC tensor as an input (i.e., the batch dim and time dim are
mixed), and it reshapes a tensor to NT(H*W)C whenever a 1-D temporal conv. is
necessary. This is to run this on TPU efficiently.
"""
import functools
from typing import Any, Dict, Mapping, List, Optional
from absl import logging
import numpy as np
import tensorflow as tf
from official.modeling import hyperparams
from official.vision.beta.modeling import factory_3d as model_factory
from official.vision.beta.modeling.backbones import factory as backbone_factory
from official.vision.beta.projects.assemblenet.configs import assemblenet as cfg
from official.vision.beta.projects.assemblenet.modeling import assemblenet as asn
from official.vision.beta.projects.assemblenet.modeling import rep_flow_2d_layer as rf
layers = tf.keras.layers
def softmax_merge_peer_attentions(peers):
"""Merge multiple peer-attention vectors with softmax weighted sum.
Summation weights are to be learned.
Args:
peers: A list of `Tensors` of size `[batch*time, channels]`.
Returns:
The output `Tensor` of size `[batch*time, channels]`.
"""
data_format = tf.keras.backend.image_data_format()
dtype = peers[0].dtype
assert data_format == 'channels_last'
initial_attn_weights = tf.keras.initializers.TruncatedNormal(stddev=0.01)(
[len(peers)])
attn_weights = tf.cast(tf.nn.softmax(initial_attn_weights), dtype)
weighted_peers = []
for i, peer in enumerate(peers):
weighted_peers.append(attn_weights[i] * peer)
return tf.add_n(weighted_peers)
def apply_attention(inputs,
attention_mode=None,
attention_in=None,
use_5d_mode=False):
"""Applies peer-attention or self-attention to the input tensor.
Depending on the attention_mode, this function either applies channel-wise
self-attention or peer-attention. For the peer-attention, the function
combines multiple candidate attention vectors (given as attention_in), by
learning softmax-sum weights described in the AssembleNet++ paper. Note that
the attention is applied individually for each frame, which showed better
accuracies than using video-level attention.
Args:
inputs: A `Tensor`. Either 4D or 5D, depending of use_5d_mode.
attention_mode: `str` specifying mode. If not `peer', does self-attention.
attention_in: A list of `Tensors' of size [batch*time, channels].
use_5d_mode: `bool` indicating whether the inputs are in 5D tensor or 4D.
Returns:
The output `Tensor` after attention is applied.
"""
data_format = tf.keras.backend.image_data_format()
assert data_format == 'channels_last'
if use_5d_mode:
h_channel_loc = 2
else:
h_channel_loc = 1
if attention_mode == 'peer':
attn = softmax_merge_peer_attentions(attention_in)
else:
attn = tf.math.reduce_mean(inputs, [h_channel_loc, h_channel_loc + 1])
attn = tf.keras.layers.Dense(
units=inputs.shape[-1],
kernel_initializer=tf.random_normal_initializer(stddev=.01))(
inputs=attn)
attn = tf.math.sigmoid(attn)
channel_attn = tf.expand_dims(
tf.expand_dims(attn, h_channel_loc), h_channel_loc)
inputs = tf.math.multiply(inputs, channel_attn)
return inputs
class _ApplyEdgeWeight(layers.Layer):
"""Multiply weight on each input tensor.
A weight is assigned for each connection (i.e., each input tensor). This layer
is used by the fusion_with_peer_attention to compute the weighted inputs.
"""
def __init__(self,
weights_shape,
index: Optional[int] = None,
use_5d_mode: bool = False,
model_edge_weights: Optional[List[Any]] = None,
num_object_classes: Optional[int] = None,
**kwargs):
"""Constructor.
Args:
weights_shape: A list of integers. Each element means number of edges.
index: `int` index of the block within the AssembleNet architecture. Used
for summation weight initial loading.
use_5d_mode: `bool` indicating whether the inputs are in 5D tensor or 4D.
model_edge_weights: AssembleNet++ model structure connection weights in
the string format.
num_object_classes: number of classes in the object input stream, used when
  the AssembleNet++ structure takes object inputs (e.g., 151 classes for
  ADE-20K).
**kwargs: pass through arguments.
"""
super(_ApplyEdgeWeight, self).__init__(**kwargs)
self._weights_shape = weights_shape
self._index = index
self._use_5d_mode = use_5d_mode
self._model_edge_weights = model_edge_weights
self._num_object_classes = num_object_classes
data_format = tf.keras.backend.image_data_format()
assert data_format == 'channels_last'
def get_config(self):
config = {
'weights_shape': self._weights_shape,
'index': self._index,
'use_5d_mode': self._use_5d_mode,
'model_edge_weights': self._model_edge_weights,
'num_object_classes': self._num_object_classes
}
base_config = super(_ApplyEdgeWeight, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def build(self, input_shape: tf.TensorShape):
if self._weights_shape[0] == 1:
self._edge_weights = 1.0
return
if self._index is None or not self._model_edge_weights:
self._edge_weights = self.add_weight(
shape=self._weights_shape,
initializer=tf.keras.initializers.TruncatedNormal(
mean=0.0, stddev=0.01),
trainable=True,
name='agg_weights')
else:
initial_weights_after_sigmoid = np.asarray(
self._model_edge_weights[self._index][0]).astype('float32')
# Initial_weights_after_sigmoid is never 0, as the initial weights are
# based on the results of a successful connectivity search.
initial_weights = -np.log(1. / initial_weights_after_sigmoid - 1.)
self._edge_weights = self.add_weight(
shape=self._weights_shape,
initializer=tf.constant_initializer(initial_weights),
trainable=False,
name='agg_weights')
def call(self,
inputs: List[tf.Tensor],
training: Optional[bool] = None) -> Mapping[Any, List[tf.Tensor]]:
use_5d_mode = self._use_5d_mode
dtype = inputs[0].dtype
assert len(inputs) > 1
if use_5d_mode:
h_channel_loc = 2
else:
h_channel_loc = 1
# get smallest spatial size and largest channels
sm_size = [10000, 10000]
lg_channel = 0
for inp in inputs:
# assume batch X height x width x channels
sm_size[0] = min(sm_size[0], inp.shape[h_channel_loc])
sm_size[1] = min(sm_size[1], inp.shape[h_channel_loc + 1])
# Note that, when using object inputs, object channel sizes are usually
# big. Since we do not want the object channel size to increase the number
# of parameters for every fusion, we exclude it when computing lg_channel.
if inp.shape[-1] > lg_channel and inp.shape[-1] != self._num_object_classes: # pylint: disable=line-too-long
lg_channel = inp.shape[3]
# loads or creates weight variables to fuse multiple inputs
weights = tf.math.sigmoid(tf.cast(self._edge_weights, dtype))
# Compute weighted inputs. We group inputs with the same channels.
per_channel_inps = dict({0: []})
for i, inp in enumerate(inputs):
if inp.shape[h_channel_loc] != sm_size[0] or inp.shape[h_channel_loc + 1] != sm_size[1]: # pylint: disable=line-too-long
assert sm_size[0] != 0
ratio = (inp.shape[h_channel_loc] + 1) // sm_size[0]
if use_5d_mode:
inp = tf.keras.layers.MaxPool3D([1, ratio, ratio], [1, ratio, ratio],
padding='same')(
inp)
else:
inp = tf.keras.layers.MaxPool2D([ratio, ratio], ratio,
padding='same')(
inp)
weights = tf.cast(weights, inp.dtype)
if inp.shape[-1] in per_channel_inps:
per_channel_inps[inp.shape[-1]].append(weights[i] * inp)
else:
per_channel_inps.update({inp.shape[-1]: [weights[i] * inp]})
return per_channel_inps
def fusion_with_peer_attention(inputs: List[tf.Tensor],
index: Optional[int] = None,
attention_mode: Optional[str] = None,
attention_in: Optional[List[tf.Tensor]] = None,
use_5d_mode: bool = False,
model_edge_weights: Optional[List[Any]] = None,
num_object_classes: Optional[int] = None):
"""Weighted summation of multiple tensors, while using peer-attention.
Summation weights are to be learned. Uses spatial max pooling and 1x1 conv.
to match their sizes. Before the summation, each connection (i.e., each input)
itself is scaled with channel-wise peer-attention. Notice that attention is
applied for each connection, conditioned based on attention_in.
Args:
inputs: A list of `Tensors`. Either 4D or 5D, depending of use_5d_mode.
index: `int` index of the block within the AssembleNet architecture. Used
for summation weight initial loading.
attention_mode: `str` specifying mode. If not `peer', does self-attention.
attention_in: A list of `Tensors' of size [batch*time, channels].
use_5d_mode: `bool` indicating whether the inputs are in 5D tensor or 4D.
model_edge_weights: AssembleNet model structure connection weights in the
string format.
num_object_classes: number of classes in the object input stream, used when
  the AssembleNet++ structure takes object inputs (e.g., 151 classes for
  ADE-20K).
Returns:
The output `Tensor` after concatenation.
"""
if use_5d_mode:
h_channel_loc = 2
conv_function = asn.conv3d_same_padding
else:
h_channel_loc = 1
conv_function = asn.conv2d_fixed_padding
# If only 1 input.
if len(inputs) == 1:
inputs[0] = apply_attention(inputs[0], attention_mode, attention_in,
use_5d_mode)
return inputs[0]
# get smallest spatial size and largest channels
sm_size = [10000, 10000]
lg_channel = 0
for inp in inputs:
# assume batch X height x width x channels
sm_size[0] = min(sm_size[0], inp.shape[h_channel_loc])
sm_size[1] = min(sm_size[1], inp.shape[h_channel_loc + 1])
# Note that, when using object inputs, object channel sizes are usually big.
# Since we do not want the object channel size to increase the number of
# parameters for every fusion, we exclude it when computing lg_channel.
if inp.shape[-1] > lg_channel and inp.shape[-1] != num_object_classes: # pylint: disable=line-too-long
lg_channel = inp.shape[3]
per_channel_inps = _ApplyEdgeWeight(
weights_shape=[len(inputs)],
index=index,
use_5d_mode=use_5d_mode,
model_edge_weights=model_edge_weights)(
inputs)
# Implementation of connectivity with peer-attention
if attention_mode:
for key, channel_inps in per_channel_inps.items():
for idx in range(len(channel_inps)):
with tf.name_scope('Connection_' + str(key) + '_' + str(idx)):
channel_inps[idx] = apply_attention(channel_inps[idx], attention_mode,
attention_in, use_5d_mode)
# Adding 1x1 conv layers (to match channel size) and fusing all inputs.
# We add inputs with the same channels first before applying 1x1 conv to save
# memory.
inps = []
for key, channel_inps in per_channel_inps.items():
if len(channel_inps) < 1:
continue
if len(channel_inps) == 1:
if key == lg_channel:
inp = channel_inps[0]
else:
inp = conv_function(
channel_inps[0], lg_channel, kernel_size=1, strides=1)
inps.append(inp)
else:
if key == lg_channel:
inp = tf.add_n(channel_inps)
else:
inp = conv_function(
    tf.add_n(channel_inps), lg_channel, kernel_size=1, strides=1)
inps.append(inp)
return tf.add_n(inps)
def object_conv_stem(inputs):
"""Layers for an object input stem.
It expects its input tensor to have a separate channel for each object class.
Each channel corresponds to one object class.
Args:
inputs: A `Tensor`.
Returns:
The output `Tensor`.
"""
inputs = tf.keras.layers.MaxPool2D(
pool_size=4, strides=4, padding='SAME')(
inputs=inputs)
inputs = tf.identity(inputs, 'initial_max_pool')
return inputs
class AssembleNetPlus(tf.keras.Model):
"""AssembleNet++ backbone."""
def __init__(self,
block_fn,
num_blocks: List[int],
num_frames: int,
model_structure: List[Any],
input_specs: layers.InputSpec = layers.InputSpec(
shape=[None, None, None, None, 3]),
model_edge_weights: Optional[List[Any]] = None,
use_object_input: bool = False,
attention_mode: str = 'peer',
bn_decay: float = rf.BATCH_NORM_DECAY,
bn_epsilon: float = rf.BATCH_NORM_EPSILON,
use_sync_bn: bool = False,
**kwargs):
"""Generator for AssembleNet++ models.
Args:
block_fn: `function` for the block to use within the model. Currently
  only `bottleneck_block_interleave` is supported.
num_blocks: list of 4 `int`s denoting the number of blocks to include in
each of the 4 block groups. Each group consists of blocks that take
inputs of the same resolution.
num_frames: the number of frames in the input tensor.
model_structure: AssembleNetPlus model structure in the string format.
input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
Dimension should be `[batch*time, height, width, channels]`.
model_edge_weights: AssembleNet model structure connection weight in the
string format.
use_object_input: `bool`, whether object inputs are used.
attention_mode: `str`, 'peer' for peer-attention; otherwise channel-wise
  self-attention is used.
bn_decay: `float` batch norm decay parameter to use.
bn_epsilon: `float` batch norm epsilon parameter to use.
use_sync_bn: use synchronized batch norm for TPU.
**kwargs: pass through arguments.
Returns:
Model `function` that takes in `inputs` and `is_training` and returns the
output `Tensor` of the AssembleNetPlus model.
"""
data_format = tf.keras.backend.image_data_format()
# Creation of the model graph.
logging.info('model_structure=%r', model_structure)
logging.info('model_structure=%r', model_structure)
logging.info('model_edge_weights=%r', model_edge_weights)
structure = model_structure
if use_object_input:
original_inputs = tf.keras.Input(shape=input_specs[0].shape[1:])
object_inputs = tf.keras.Input(shape=input_specs[1].shape[1:])
input_specs = input_specs[0]
else:
original_inputs = tf.keras.Input(shape=input_specs.shape[1:])
object_inputs = None
original_num_frames = num_frames
assert num_frames > 0, f'Invalid num_frames {num_frames}'
grouping = {-3: [], -2: [], -1: [], 0: [], 1: [], 2: [], 3: []}
for i in range(len(structure)):
grouping[structure[i][0]].append(i)
stem_count = len(grouping[-3]) + len(grouping[-2]) + len(grouping[-1])
assert stem_count != 0
stem_filters = 128 // stem_count
if len(input_specs.shape) == 5:
first_dim = (
input_specs.shape[0] * input_specs.shape[1]
if input_specs.shape[0] and input_specs.shape[1] else -1)
reshape_inputs = tf.reshape(original_inputs,
(first_dim,) + input_specs.shape[2:])
elif len(input_specs.shape) == 4:
reshape_inputs = original_inputs
else:
raise ValueError(
f'Expect input spec to be 4 or 5 dimensions {input_specs.shape}')
if grouping[-2]:
# Instead of loading optical flows as inputs from data pipeline, we are
# applying the "Representation Flow" to RGB frames so that we can compute
# the flow within TPU/GPU on fly. It's essentially optical flow since we
# do it with RGBs.
axis = 3 if data_format == 'channels_last' else 1
flow_inputs = rf.RepresentationFlow(
original_num_frames,
depth=reshape_inputs.shape.as_list()[axis],
num_iter=40,
bottleneck=1)(
reshape_inputs)
streams = []
for i in range(len(structure)):
with tf.name_scope('Node_' + str(i)):
if structure[i][0] == -1:
inputs = asn.rgb_conv_stem(
reshape_inputs,
original_num_frames,
stem_filters,
temporal_dilation=structure[i][1],
bn_decay=bn_decay,
bn_epsilon=bn_epsilon,
use_sync_bn=use_sync_bn)
streams.append(inputs)
elif structure[i][0] == -2:
inputs = asn.flow_conv_stem(
flow_inputs,
stem_filters,
temporal_dilation=structure[i][1],
bn_decay=bn_decay,
bn_epsilon=bn_epsilon,
use_sync_bn=use_sync_bn)
streams.append(inputs)
elif structure[i][0] == -3:
# In order to use the object inputs, you need to feed your object
# input tensor here.
inputs = object_conv_stem(object_inputs)
streams.append(inputs)
else:
block_number = structure[i][0]
combined_inputs = [
streams[structure[i][1][j]]
for j in range(0, len(structure[i][1]))
]
logging.info(grouping)
nodes_below = []
for k in range(-3, structure[i][0]):
nodes_below = nodes_below + grouping[k]
peers = []
if attention_mode:
lg_channel = -1
# To show structures for attention we show nodes_below
logging.info(nodes_below)
for k in nodes_below:
logging.info(streams[k].shape)
lg_channel = max(streams[k].shape[3], lg_channel)
for node_index in nodes_below:
attn = tf.reduce_mean(streams[node_index], [1, 2])
attn = tf.keras.layers.Dense(
units=lg_channel,
kernel_initializer=tf.random_normal_initializer(stddev=.01))(
inputs=attn)
peers.append(attn)
combined_inputs = fusion_with_peer_attention(
combined_inputs,
index=i,
attention_mode=attention_mode,
attention_in=peers,
use_5d_mode=False)
graph = asn.block_group(
inputs=combined_inputs,
filters=structure[i][2],
block_fn=block_fn,
blocks=num_blocks[block_number],
strides=structure[i][4],
name='block_group' + str(i),
block_level=structure[i][0],
num_frames=num_frames,
temporal_dilation=structure[i][3])
streams.append(graph)
if use_object_input:
inputs = [original_inputs, object_inputs]
else:
inputs = original_inputs
super(AssembleNetPlus, self).__init__(
inputs=inputs, outputs=streams, **kwargs)
@tf.keras.utils.register_keras_serializable(package='Vision')
class AssembleNetPlusModel(tf.keras.Model):
"""An AssembleNet++ model builder."""
def __init__(self,
backbone,
num_classes,
num_frames: int,
model_structure: List[Any],
input_specs: Optional[Dict[str,
tf.keras.layers.InputSpec]] = None,
max_pool_predictions: bool = False,
use_object_input: bool = False,
**kwargs):
if not input_specs:
input_specs = {
'image': layers.InputSpec(shape=[None, None, None, None, 3])
}
if use_object_input and 'object' not in input_specs:
input_specs['object'] = layers.InputSpec(shape=[None, None, None, None])
self._self_setattr_tracking = False
self._config_dict = {
'backbone': backbone,
'num_classes': num_classes,
'num_frames': num_frames,
'input_specs': input_specs,
'model_structure': model_structure,
}
self._input_specs = input_specs
self._backbone = backbone
grouping = {-3: [], -2: [], -1: [], 0: [], 1: [], 2: [], 3: []}
for i in range(len(model_structure)):
grouping[model_structure[i][0]].append(i)
inputs = {
k: tf.keras.Input(shape=v.shape[1:]) for k, v in input_specs.items()
}
if use_object_input:
streams = self._backbone(inputs=[inputs['image'], inputs['object']])
else:
streams = self._backbone(inputs=inputs['image'])
outputs = asn.multi_stream_heads(
streams,
grouping[3],
num_frames,
num_classes,
max_pool_predictions=max_pool_predictions)
super(AssembleNetPlusModel, self).__init__(
inputs=inputs, outputs=outputs, **kwargs)
@property
def checkpoint_items(self):
"""Returns a dictionary of items to be additionally checkpointed."""
return dict(backbone=self.backbone)
@property
def backbone(self):
return self._backbone
def get_config(self):
return self._config_dict
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**config)
def assemblenet_plus(assemblenet_depth: int,
num_classes: int,
num_frames: int,
model_structure: List[Any],
input_specs: layers.InputSpec = layers.InputSpec(
shape=[None, None, None, None, 3]),
model_edge_weights: Optional[List[Any]] = None,
use_object_input: bool = False,
attention_mode: Optional[str] = None,
max_pool_predictions: bool = False,
**kwargs):
"""Returns the AssembleNet++ model for a given size and number of output classes."""
data_format = tf.keras.backend.image_data_format()
assert data_format == 'channels_last'
if assemblenet_depth not in asn.ASSEMBLENET_SPECS:
raise ValueError('Not a valid assemblenet_depth:', assemblenet_depth)
if use_object_input:
# assuming input_specs = [video, obj] when use_object_input = True
input_specs_dict = {'image': input_specs[0], 'object': input_specs[1]}
else:
input_specs_dict = {'image': input_specs}
params = asn.ASSEMBLENET_SPECS[assemblenet_depth]
backbone = AssembleNetPlus(
block_fn=params['block'],
num_blocks=params['num_blocks'],
num_frames=num_frames,
model_structure=model_structure,
input_specs=input_specs,
model_edge_weights=model_edge_weights,
use_object_input=use_object_input,
attention_mode=attention_mode,
**kwargs)
return AssembleNetPlusModel(
backbone,
num_classes=num_classes,
num_frames=num_frames,
model_structure=model_structure,
input_specs=input_specs_dict,
use_object_input=use_object_input,
max_pool_predictions=max_pool_predictions,
**kwargs)
@backbone_factory.register_backbone_builder('assemblenet_plus')
def build_assemblenet_plus(
input_specs: tf.keras.layers.InputSpec,
backbone_config: hyperparams.Config,
norm_activation_config: hyperparams.Config,
l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None
) -> tf.keras.Model:
"""Builds assemblenet++ backbone."""
del l2_regularizer
backbone_type = backbone_config.type
backbone_cfg = backbone_config.get()
assert backbone_type == 'assemblenet_plus'
assemblenet_depth = int(backbone_cfg.model_id)
if assemblenet_depth not in asn.ASSEMBLENET_SPECS:
raise ValueError('Not a valid assemblenet_depth:', assemblenet_depth)
model_structure, model_edge_weights = cfg.blocks_to_flat_lists(
backbone_cfg.blocks)
params = asn.ASSEMBLENET_SPECS[assemblenet_depth]
block_fn = functools.partial(
params['block'],
use_sync_bn=norm_activation_config.use_sync_bn,
bn_decay=norm_activation_config.norm_momentum,
bn_epsilon=norm_activation_config.norm_epsilon)
backbone = AssembleNetPlus(
block_fn=block_fn,
num_blocks=params['num_blocks'],
num_frames=backbone_cfg.num_frames,
model_structure=model_structure,
input_specs=input_specs,
model_edge_weights=model_edge_weights,
use_object_input=backbone_cfg.use_object_input,
attention_mode=backbone_cfg.attention_mode,
use_sync_bn=norm_activation_config.use_sync_bn,
bn_decay=norm_activation_config.norm_momentum,
bn_epsilon=norm_activation_config.norm_epsilon)
logging.info('Number of parameters in AssembleNet++ backbone: %f M.',
backbone.count_params() / 10.**6)
return backbone
@model_factory.register_model_builder('assemblenet_plus')
def build_assemblenet_plus_model(
input_specs: tf.keras.layers.InputSpec,
model_config: cfg.AssembleNetPlusModel,
num_classes: int,
l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None):
"""Builds assemblenet++ model."""
input_specs_dict = {'image': input_specs}
backbone = build_assemblenet_plus(input_specs, model_config.backbone,
model_config.norm_activation,
l2_regularizer)
backbone_cfg = model_config.backbone.get()
model_structure, _ = cfg.blocks_to_flat_lists(backbone_cfg.blocks)
model = AssembleNetPlusModel(
backbone,
num_classes=num_classes,
num_frames=backbone_cfg.num_frames,
model_structure=model_structure,
input_specs=input_specs_dict,
max_pool_predictions=model_config.max_pool_predictions,
use_object_input=backbone_cfg.use_object_input)
return model
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for assemblenet++ network."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.vision.beta.projects.assemblenet.configs import assemblenet as asn_config
from official.vision.beta.projects.assemblenet.modeling import assemblenet_plus as asnp
class AssembleNetPlusTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters((50, True, ''), (50, False, ''),
(50, False, 'peer'), (50, True, 'peer'),
(50, True, 'self'), (50, False, 'self'))
def test_network_creation(self, depth, use_object_input, attention_mode):
batch_size = 2
num_frames = 32
img_size = 64
num_classes = 101  # UCF-101
num_object_classes = 151 # 151 is for ADE-20k
if use_object_input:
vid_input = (batch_size * num_frames, img_size, img_size, 3)
obj_input = (batch_size * num_frames, img_size, img_size,
num_object_classes)
input_specs = (tf.keras.layers.InputSpec(shape=(vid_input)),
tf.keras.layers.InputSpec(shape=(obj_input)))
vid_inputs = np.random.rand(batch_size * num_frames, img_size, img_size,
3)
obj_inputs = np.random.rand(batch_size * num_frames, img_size, img_size,
num_object_classes)
inputs = [vid_inputs, obj_inputs]
# We are using the full_asnp50_structure, since we feed both video and
# object.
model_structure = asn_config.full_asnp50_structure # Uses object input.
edge_weights = asn_config.full_asnp_structure_weights
else:
# video input: (batch_size, FLAGS.num_frames, image_size, image_size, 3)
input_specs = tf.keras.layers.InputSpec(
shape=(batch_size, num_frames, img_size, img_size, 3))
inputs = np.random.rand(batch_size, num_frames, img_size, img_size, 3)
# Here, we are using model_structures.asn50_structure for AssembleNet++
# instead of full_asnp50_structure. By using asn50_structure, it
# essentially becomes AssembleNet++ without objects, only requiring RGB
# inputs (and optical flow to be computed inside the model).
model_structure = asn_config.asn50_structure
edge_weights = asn_config.asn_structure_weights
model = asnp.assemblenet_plus(
assemblenet_depth=depth,
num_classes=num_classes,
num_frames=num_frames,
model_structure=model_structure,
model_edge_weights=edge_weights,
input_specs=input_specs,
use_object_input=use_object_input,
attention_mode=attention_mode,
)
outputs = model(inputs)
self.assertAllEqual(outputs.shape.as_list(), [batch_size, num_classes])
if __name__ == '__main__':
tf.test.main()
@@ -13,7 +13,16 @@
# limitations under the License.
# Lint as: python3
r"""Training driver.

Commandline:

python -m official.vision.beta.projects.assemblenet.train \
  --mode=train_and_eval --experiment=assemblenetplus_ucf101 \
  --model_dir='YOUR MODEL SAVE GS BUCKET' \
  --config_file=./official/vision/beta/projects/assemblenet/ucf101_assemblenet_plus_tpu.yaml \
  --tpu=TPU_NAME
"""
from absl import app
from absl import flags
@@ -32,6 +41,7 @@ from official.modeling import performance
# pylint: disable=unused-import
from official.vision.beta.projects.assemblenet.configs import assemblenet as asn_configs
from official.vision.beta.projects.assemblenet.modeling import assemblenet as asn
from official.vision.beta.projects.assemblenet.modeling import assemblenet_plus as asnp
# pylint: enable=unused-import

FLAGS = flags.FLAGS
@@ -53,18 +63,35 @@ def main(_):
                     f'{params.task.validation_data.feature_shape}')

  if 'assemblenet' in FLAGS.experiment:
    if 'plus' in FLAGS.experiment:
      if 'eval' in FLAGS.mode:
        # Use the feature shape in validation_data for all jobs. The number of
        # frames in train_data will be used to construct the Assemblenet++
        # model.
        params.task.model.backbone.assemblenet_plus.num_frames = (
            params.task.validation_data.feature_shape[0])
        shape = params.task.validation_data.feature_shape
      else:
        params.task.model.backbone.assemblenet_plus.num_frames = (
            params.task.train_data.feature_shape[0])
        shape = params.task.train_data.feature_shape
      logging.info('mode %r num_frames %r feature shape %r', FLAGS.mode,
                   params.task.model.backbone.assemblenet_plus.num_frames,
                   shape)
    else:
      if 'eval' in FLAGS.mode:
        # Use the feature shape in validation_data for all jobs. The number of
        # frames in train_data will be used to construct the Assemblenet model.
        params.task.model.backbone.assemblenet.num_frames = (
            params.task.validation_data.feature_shape[0])
        shape = params.task.validation_data.feature_shape
      else:
        params.task.model.backbone.assemblenet.num_frames = (
            params.task.train_data.feature_shape[0])
        shape = params.task.train_data.feature_shape
      logging.info('mode %r num_frames %r feature shape %r', FLAGS.mode,
                   params.task.model.backbone.assemblenet.num_frames, shape)
  # Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
  # can have significant impact on model speeds by utilizing float16 in case of
@@ -91,4 +118,5 @@ def main(_):

if __name__ == '__main__':
  tfm_flags.define_flags()
  flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
  app.run(main)