Commit 063cef34 authored by SunJong Park, committed by A. Unique TensorFlower

Copybara import of the project:

--
f544441ff132c14c0d95026181eb12c69b4eac1d by SunJong Park <53969182+ryan0507@users.noreply.github.com>:

Assemblenet++ Migration with TF2 (#10288)

* Assemblenet++ implementation with TF2 (UCF101 Dataset)

* pylint.sh passed

* train.py document updated

* README.md updated - AssembleNet and AssembleNet++

* YAML and configuration updated - AssembleNet and AssembleNet++

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/models/pull/10355 from tensorflow:assemblenet f544441ff132c14c0d95026181eb12c69b4eac1d
PiperOrigin-RevId: 414047541
parent 024ebd81
@@ -2,6 +2,8 @@
This repository contains the official implementations of the following papers.
The original implementations can be found
[here](https://github.com/google-research/google-research/tree/master/assemblenet).
[![Paper](http://img.shields.io/badge/Paper-arXiv.2008.03800-B3181B?logo=arXiv)](https://arxiv.org/abs/1905.13209)
[AssembleNet: Searching for Multi-Stream Neural Connectivity in Video
Architectures](https://arxiv.org/abs/1905.13209)
@@ -10,5 +12,130 @@ Architectures](https://arxiv.org/abs/1905.13209)
[AssembleNet++: Assembling Modality Representations via Attention
Connections](https://arxiv.org/abs/2008.08072)
**DISCLAIMER**: AssembleNet++ implementation is still under development. No support
will be provided during the development phase.
## Description
### AssembleNet vs. AssembleNet++
AssembleNet and AssembleNet++ both focus on neural connectivity search for
multi-stream video CNN architectures. They learn weights for the connections
between multiple convolutional blocks (composed of (2+1)D or 3D residual
modules) organized sequentially or in parallel, thereby optimizing the neural
architecture for the data/task.
AssembleNet++ adds *peer-attention* to the basic AssembleNet, which allows each
conv. block connection to be conditioned differently based on another block. It
is a form of channel-wise attention, which we found to be beneficial.
<img width="1158" alt="peer_attention" src="https://user-images.githubusercontent.com/53969182/135665233-e64ccda1-7dd3-45f2-9d77-5c4515703f13.png">
The code is provided in [assemblenet.py](modeling/assemblenet.py) and
[assemblenet_plus.py](modeling/assemblenet_plus.py). Notice that the provided
code uses (2+1)D residual modules as the building blocks of AssembleNet/++, but
you can use your own module while still benefitting from the connectivity search
of AssembleNet/++.
### Neural Architecture Search
As described in the [AssembleNet](https://arxiv.org/abs/1905.13209) paper, the
models we provide in the [config files](configs/assemblenet.py) are the result
of architecture search/learning.
The architecture search in AssembleNet (and AssembleNet++) has two components:
(i) convolutional block configuration search using an evolutionary algorithm,
and (ii) one-shot differentiable connection search. We did not include the code
for the first part (i.e., evolution), as it relies on separate infrastructure
and substantially more computation. The second part (i.e., differentiable
search) is included, however, which allows you to use the code to search for
the best connectivity for your own models.
That is, as also described in the
[AssembleNet++](https://arxiv.org/abs/2008.08072) paper, once the convolutional
blocks are decided (based on the search or manually), you can use the provided
code to obtain the best block connections and learn attention connectivity in a
one-shot differentiable way. You just need to train the network (with
`FLAGS.model_edge_weights` set to `[]`), and the connectivity search will happen
simultaneously; a minimal sketch of this idea follows.
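Here is a minimal, hedged sketch of one-shot differentiable connection search:
each block combines its candidate input connections through per-edge weights
learned jointly with the rest of the network. The class and variable names are
illustrative, and the real implementation in
[assemblenet.py](modeling/assemblenet.py) handles additional details such as
feature resizing and the sigmoid/softmax combine method.
```python
import tensorflow as tf

class WeightedConnections(tf.keras.layers.Layer):
  """Sigmoid-gated sum over a block's candidate input connections."""

  def __init__(self, num_inputs: int, **kwargs):
    super().__init__(**kwargs)
    # One trainable logit per candidate input edge, trained with the network.
    self.edge_logits = self.add_weight(
        'edge_logits', shape=[num_inputs], initializer='zeros', trainable=True)

  def call(self, inputs):
    gates = tf.sigmoid(self.edge_logits)
    # Weighted sum of all candidate inputs (assumed to share the same shape).
    return tf.add_n([g * x for g, x in zip(tf.unstack(gates), inputs)])

# Usage: fuse three same-shaped candidate inputs into one block input.
combine = WeightedConnections(num_inputs=3)
features = [tf.random.normal([2, 8, 8, 16]) for _ in range(3)]
fused = combine(features)  # shape: [2, 8, 8, 16]
```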
### AssembleNet and AssembleNet++ Structure Format
The format we use to specify AssembleNet/++ architectures is as follows. The
architecture is a `list` corresponding to a graph representation of the
network, where a node is a convolutional block and an edge specifies a
connection from one block to another. Each node (in the structure list) is
itself a `list` with the following format: `[block_level,
[list_of_input_blocks], number_filter, temporal_dilation, spatial_stride]`.
`[list_of_input_blocks]` should be the list of node indexes whose values are
less than the index of the node itself. The 'stems' of the network, which
directly take raw inputs, follow a different node format: `[stem_type,
temporal_dilation]`. `stem_type` is -1 for the RGB stem and -2 for the
optical-flow stem; -3 is reserved for the object segmentation input.
In AssembleNet++lite, instead of passing a single `int` for `number_filter`, we
pass a list/tuple of three `int`s. They specify the number of channels to be
used for each layer in the inverted bottleneck modules.
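For illustration, a tiny, made-up structure list in this format might look as
follows (this is not a searched architecture; see `asn50_structure` in the
[config files](configs/assemblenet.py) for a real one):
```python
# Two stems followed by three convolutional blocks.
# Stems: [stem_type, temporal_dilation]; blocks:
# [block_level, [list_of_input_blocks], number_filter, temporal_dilation, spatial_stride].
toy_structure = [
    [-1, 4],                   # node 0: RGB stem
    [-2, 1],                   # node 1: optical-flow stem
    [1, [0, 1], 32, 1, 1],     # node 2: level-1 block fed by both stems
    [2, [0, 1, 2], 64, 2, 2],  # node 3: level-2 block fed by nodes 0, 1 and 2
    [3, [2, 3], 128, 2, 2],    # node 4: level-3 block fed by nodes 2 and 3
]
```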
### Optical Flow and Data Loading
Instead of loading optical flow as an input from the data pipeline, we apply
[Representation Flow](https://github.com/piergiaj/representation-flow-cvpr19) to
RGB frames so that the flow can be computed within the TPU/GPU on the fly. It is
essentially optical flow, since it is computed directly from RGB frames. The
benefit is that we do not need external optical flow extraction or data loading.
You only need to feed RGB frames, and the flow will be computed internally, as
in the example below.
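For example (mirroring the included unit test; the shapes and random input here
are purely illustrative), the AssembleNet++ model is built with an RGB-only
input spec and fed plain RGB frames:
```python
import numpy as np
import tensorflow as tf

from official.vision.beta.projects.assemblenet.configs import assemblenet as asn_config
from official.vision.beta.projects.assemblenet.modeling import assemblenet_plus as asnp

batch_size, num_frames, img_size, num_classes = 2, 32, 64, 101
model = asnp.assemblenet_plus(
    assemblenet_depth=50,
    num_classes=num_classes,
    num_frames=num_frames,
    model_structure=asn_config.asn50_structure,
    model_edge_weights=asn_config.asn_structure_weights,
    input_specs=tf.keras.layers.InputSpec(
        shape=(batch_size, num_frames, img_size, img_size, 3)),
    use_object_input=False,
    attention_mode='peer',
)
# Only RGB frames are fed; the flow-like features are computed inside the model.
rgb = np.random.rand(batch_size, num_frames, img_size, img_size, 3)
logits = model(rgb)  # shape: [batch_size, num_classes]
```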
## History
2021/10/02: AssembleNet and AssembleNet++ implementations with the UCF101
dataset provided.
## Authors
* SunJong Park ([@GitHub ryan0507](https://github.com/ryan0507))
* HyeYoon Lee ([@GitHub hylee817](https://github.com/hylee817))
## Table of Contents
* [AssembleNet vs AssembleNet++](#assemblenet-vs-assemblenet)
* [Neural Architecture Search](#neural-architecture-search)
* [AssembleNet and AssembleNet++ Structure Format](#assemblenet-and-assemblenet-structure-format)
* [Optical Flow and Data Loading](#optical-flow-and-data-loading)
## Requirements
[![TensorFlow 2.5](https://img.shields.io/badge/TensorFlow-2.5.0-FF6F00?logo=tensorflow)](https://github.com/tensorflow/tensorflow/releases/tag/v2.5.0)
[![Python 3.8](https://img.shields.io/badge/Python-3.8-3776AB)](https://www.python.org/downloads/release/python-380/)
## Training and Evaluation
Example of training AssembleNet with UCF101 TF Datasets.
```bash
python -m official.vision.beta.projects.assemblenet.train \
--mode=train_and_eval --experiment=assemblenet_ucf101 \
--model_dir='YOUR_GS_BUCKET_TO_SAVE_MODEL' \
--config_file=./official/vision/beta/projects/assemblenet/ucf101_assemblenet_tpu.yaml \
--tpu=TPU_NAME
```
Example of training AssembleNet++ with UCF101 TF Datasets.
```bash
python -m official.vision.beta.projects.assemblenet.train \
--mode=train_and_eval --experiment=assemblenetplus_ucf101 \
--model_dir='YOUR_GS_BUCKET_TO_SAVE_MODEL' \
--config_file=./official/vision/beta/projects/assemblenet/ucf101_assemblenet_plus_tpu.yaml \
--tpu=TPU_NAME
```
Currently, we provide experiments with the kinetics400, kinetics500,
kinetics600, and UCF101 datasets. If you want to add a new experiment, you
should register a new configuration with `exp_factory`, as sketched below.
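For instance, a new experiment could be registered roughly as follows (a
hedged sketch modeled on the UCF101 factories in
[configs/assemblenet.py](configs/assemblenet.py); `assemblenet_my_dataset` and
the choice of base config are placeholders):
```python
from official.core import exp_factory
from official.vision.beta.configs import video_classification
from official.vision.beta.projects.assemblenet.configs import assemblenet as asn_configs

@exp_factory.register_config_factory('assemblenet_my_dataset')
def assemblenet_my_dataset():
  """Registers a new AssembleNet experiment under the name above."""
  exp = video_classification.video_classification_ucf101()  # reuse an existing base config
  model = asn_configs.AssembleNetModel()
  model.backbone.assemblenet.blocks = asn_configs.flat_lists_to_blocks(
      asn_configs.asn50_structure, asn_configs.asn_structure_weights)
  model.backbone.assemblenet.num_frames = 32
  exp.task.model = model
  return exp
```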
@@ -34,8 +34,9 @@ used for each layer in the inverted bottleneck modules.
The structure_weights specify the learned connection weights.
"""
from typing import List, Tuple
import dataclasses
from typing import List, Tuple, Optional
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import hyperparams
@@ -176,26 +177,47 @@ class AssembleNet(hyperparams.Config):
blocks: Tuple[BlockSpec, ...] = tuple()
@dataclasses.dataclass
class AssembleNetPlus(hyperparams.Config):
model_id: str = '50'
num_frames: int = 0
attention_mode: str = 'None'
blocks: Tuple[BlockSpec, ...] = tuple()
use_object_input: bool = False
@dataclasses.dataclass
class Backbone3D(backbones_3d.Backbone3D):
"""Configuration for backbones.
Attributes:
type: 'str', type of backbone to be used, one of the fields below.
resnet: resnet3d backbone config.
assemblenet: AssembleNet backbone config.
assemblenet_plus : AssembleNetPlus backbone config.
"""
type: str = 'assemblenet'
type: Optional[str] = None
assemblenet: AssembleNet = AssembleNet()
assemblenet_plus: AssembleNetPlus = AssembleNetPlus()
@dataclasses.dataclass
class AssembleNetModel(video_classification.VideoClassificationModel):
"""The AssembleNet model config."""
model_type: str = 'assemblenet'
backbone: Backbone3D = Backbone3D()
backbone: Backbone3D = Backbone3D(type='assemblenet')
norm_activation: common.NormActivation = common.NormActivation(
norm_momentum=0.99, norm_epsilon=1e-5, use_sync_bn=True)
max_pool_preditions: bool = False
max_pool_predictions: bool = False
@dataclasses.dataclass
class AssembleNetPlusModel(video_classification.VideoClassificationModel):
"""The AssembleNet model config."""
model_type: str = 'assemblenet_plus'
backbone: Backbone3D = Backbone3D(type='assemblenet_plus')
norm_activation: common.NormActivation = common.NormActivation(
norm_momentum=0.99, norm_epsilon=1e-5, use_sync_bn=True)
max_pool_predictions: bool = False
@exp_factory.register_config_factory('assemblenet50_kinetics600')
@@ -223,3 +245,43 @@ def assemblenet_kinetics600() -> cfg.ExperimentConfig:
f'{exp.task.model.backbone.assemblenet}')
return exp
@exp_factory.register_config_factory('assemblenet_ucf101')
def assemblenet_ucf101() -> cfg.ExperimentConfig:
"""Video classification on Videonet with assemblenet."""
exp = video_classification.video_classification_ucf101()
exp.task.train_data.dtype = 'bfloat16'
exp.task.validation_data.dtype = 'bfloat16'
feature_shape = (32, 224, 224, 3)
model = AssembleNetModel()
model.backbone.assemblenet.blocks = flat_lists_to_blocks(
asn50_structure, asn_structure_weights)
model.backbone.assemblenet.num_frames = feature_shape[0]
exp.task.model = model
assert exp.task.model.backbone.assemblenet.num_frames > 0, (
f'backbone num_frames '
f'{exp.task.model.backbone.assemblenet}')
return exp
@exp_factory.register_config_factory('assemblenetplus_ucf101')
def assemblenetplus_ucf101() -> cfg.ExperimentConfig:
"""Video classification on Videonet with assemblenet."""
exp = video_classification.video_classification_ucf101()
exp.task.train_data.dtype = 'bfloat16'
exp.task.validation_data.dtype = 'bfloat16'
feature_shape = (32, 224, 224, 3)
model = AssembleNetPlusModel()
model.backbone.assemblenet_plus.blocks = flat_lists_to_blocks(
asn50_structure, asn_structure_weights)
model.backbone.assemblenet_plus.num_frames = feature_shape[0]
exp.task.model = model
assert exp.task.model.backbone.assemblenet_plus.num_frames > 0, (
f'backbone num_frames '
f'{exp.task.model.backbone.assemblenet_plus}')
return exp
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
from absl.testing import parameterized
import tensorflow as tf
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.vision.beta.configs import video_classification as exp_cfg
from official.vision.beta.projects.assemblenet.configs import assemblenet
class AssemblenetTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters(
('assemblenet50_kinetics600',),)
def test_assemblenet_configs(self, config_name):
config = exp_factory.get_exp_config(config_name)
self.assertIsInstance(config, cfg.ExperimentConfig)
self.assertIsInstance(config.task, exp_cfg.VideoClassificationTask)
self.assertIsInstance(config.task.model, assemblenet.AssembleNetModel)
self.assertIsInstance(config.task.train_data, exp_cfg.DataConfig)
config.task.train_data.is_training = None
with self.assertRaises(KeyError):
config.validate()
def test_configs_conversion(self):
blocks = assemblenet.flat_lists_to_blocks(assemblenet.asn50_structure,
assemblenet.asn_structure_weights)
re_structure, re_weights = assemblenet.blocks_to_flat_lists(blocks)
self.assertAllEqual(
re_structure, assemblenet.asn50_structure, msg='asn50_structure')
self.assertAllEqual(
re_weights,
assemblenet.asn_structure_weights,
msg='asn_structure_weights')
if __name__ == '__main__':
tf.test.main()
# AssembleNet++ structure, video classification on the UCF-101 dataset
# --experiment_type=assemblenetplus_ucf101
# device : TPU v3-8
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'bfloat16'
task:
model:
backbone:
assemblenet_plus:
model_id: 50
num_frames: 32
attention_mode: 'peer'
use_object_input: false
type: 'assemblenet_plus'
dropout_rate: 0.5
norm_activation:
activation: relu
norm_momentum: 0.99
norm_epsilon: 0.00001
use_sync_bn: true
max_pool_predictions: true
train_data:
is_training: true
global_batch_size: 64
dtype: 'bfloat16'
tfds_data_dir: 'gs://oss-yonsei/tensorflow_datasets/'
validation_data:
is_training: false
global_batch_size: 64
dtype: 'bfloat16'
tfds_data_dir: 'gs://oss-yonsei/tensorflow_datasets/'
drop_remainder: true
trainer:
train_steps: 900000 # 500 epochs
validation_steps: 144
validation_interval: 144
steps_per_loop: 144 # NUM_EXAMPLES (9537) // global_batch_size
summary_interval: 144
checkpoint_interval: 144
optimizer_config:
optimizer:
type: 'sgd'
sgd:
momentum: 0.9
learning_rate:
type: 'exponential'
exponential:
initial_learning_rate: 0.008 # 0.008 * batch_size / 128
decay_steps: 532 # 2.5 * steps_per_epoch
decay_rate: 0.96
staircase: true
warmup:
type: 'linear'
linear:
warmup_steps: 50
# AssembleNet structure, video classification on the UCF-101 dataset
# --experiment_type=assemblenet_ucf101
# device : TPU v3-8
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'bfloat16'
task:
model:
backbone:
assemblenet:
model_id: 101
num_frames: 32
combine_method: 'sigmoid'
type: 'assemblenet'
dropout_rate: 0.5
norm_activation:
activation: relu
norm_momentum: 0.99
norm_epsilon: 0.00001
use_sync_bn: true
max_pool_predictions: true
train_data:
is_training: true
global_batch_size: 32
dtype: 'bfloat16'
tfds_data_dir: 'gs://oss-yonsei/tensorflow_datasets/'
validation_data:
is_training: false
global_batch_size: 32
dtype: 'bfloat16'
tfds_data_dir: 'gs://oss-yonsei/tensorflow_datasets/'
drop_remainder: true
trainer:
train_steps: 90000 # 500 epochs
validation_steps: 288
validation_interval: 288
steps_per_loop: 288 # NUM_EXAMPLES (9537) // global_batch_size
summary_interval: 288
checkpoint_interval: 288
optimizer_config:
optimizer:
type: 'sgd'
sgd:
momentum: 0.9
learning_rate:
type: 'exponential'
exponential:
initial_learning_rate: 0.008 # 0.008 * batch_size / 128
decay_steps: 1024 # 2.5 * steps_per_epoch
decay_rate: 0.96
staircase: true
warmup:
type: 'linear'
linear:
warmup_steps: 50
@@ -686,7 +686,7 @@ def multi_stream_heads(streams,
final_nodes,
num_frames,
num_classes,
max_pool_preditions: bool = False):
max_pool_predictions: bool = False):
"""Layers for the classification heads.
Args:
@@ -694,7 +694,7 @@
final_nodes: A list of `int` where classification heads will be added.
num_frames: `int` number of frames in the input tensor.
num_classes: `int` number of possible classes for video classification.
max_pool_preditions: Use max-pooling on predictions instead of mean
max_pool_predictions: Use max-pooling on predictions instead of mean
pooling on features. It helps if you have more than 32 frames.
Returns:
@@ -709,7 +709,7 @@
net = tf.identity(net, 'final_avg_pool0')
net = tf.reshape(net, [-1, num_frames, num_channels])
if not max_pool_preditions:
if not max_pool_predictions:
net = tf.reduce_mean(net, 1)
return net
@@ -730,7 +730,7 @@ def multi_stream_heads(streams,
kernel_initializer=tf.random_normal_initializer(stddev=.01))(
inputs=outputs)
outputs = tf.identity(outputs, 'final_dense0')
if max_pool_preditions:
if max_pool_predictions:
pre_logits = outputs / np.sqrt(num_frames)
acts = tf.nn.softmax(pre_logits, axis=1)
outputs = tf.math.multiply(outputs, acts)
@@ -894,7 +894,7 @@ class AssembleNetModel(tf.keras.Model):
model_structure: List[Any],
input_specs: Optional[Mapping[str,
tf.keras.layers.InputSpec]] = None,
max_pool_preditions: bool = False,
max_pool_predictions: bool = False,
**kwargs):
if not input_specs:
input_specs = {
@@ -924,7 +924,7 @@
grouping[3],
num_frames,
num_classes,
max_pool_preditions=max_pool_preditions)
max_pool_predictions=max_pool_predictions)
super(AssembleNetModel, self).__init__(
inputs=inputs, outputs=outputs, **kwargs)
@@ -981,7 +981,7 @@ def assemblenet_v1(assemblenet_depth: int,
input_specs: layers.InputSpec = layers.InputSpec(
shape=[None, None, None, None, 3]),
model_edge_weights: Optional[List[Any]] = None,
max_pool_preditions: bool = False,
max_pool_predictions: bool = False,
combine_method: str = 'sigmoid',
**kwargs):
"""Returns the AssembleNet model for a given size and number of output classes."""
@@ -1009,7 +1009,7 @@
num_frames=num_frames,
model_structure=model_structure,
input_specs=input_specs_dict,
max_pool_preditions=max_pool_preditions,
max_pool_predictions=max_pool_predictions,
**kwargs)
@@ -1025,7 +1025,7 @@ def build_assemblenet_v1(
backbone_type = backbone_config.type
backbone_cfg = backbone_config.get()
assert backbone_type == 'assemblenet'
assert 'assemblenet' in backbone_type
assemblenet_depth = int(backbone_cfg.model_id)
if assemblenet_depth not in ASSEMBLENET_SPECS:
@@ -1072,5 +1072,5 @@
num_frames=backbone_cfg.num_frames,
model_structure=model_structure,
input_specs=input_specs_dict,
max_pool_preditions=model_config.max_pool_preditions)
max_pool_predictions=model_config.max_pool_predictions)
return model
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for assemblenet++ network."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.vision.beta.projects.assemblenet.configs import assemblenet as asn_config
from official.vision.beta.projects.assemblenet.modeling import assemblenet_plus as asnp
class AssembleNetPlusTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters((50, True, ''), (50, False, ''),
(50, False, 'peer'), (50, True, 'peer'),
(50, True, 'self'), (50, False, 'self'))
def test_network_creation(self, depth, use_object_input, attention_mode):
batch_size = 2
num_frames = 32
img_size = 64
num_classes = 101  # ucf-101
num_object_classes = 151 # 151 is for ADE-20k
if use_object_input:
vid_input = (batch_size * num_frames, img_size, img_size, 3)
obj_input = (batch_size * num_frames, img_size, img_size,
num_object_classes)
input_specs = (tf.keras.layers.InputSpec(shape=(vid_input)),
tf.keras.layers.InputSpec(shape=(obj_input)))
vid_inputs = np.random.rand(batch_size * num_frames, img_size, img_size,
3)
obj_inputs = np.random.rand(batch_size * num_frames, img_size, img_size,
num_object_classes)
inputs = [vid_inputs, obj_inputs]
# We are using the full_asnp50_structure, since we feed both video and
# object.
model_structure = asn_config.full_asnp50_structure # Uses object input.
edge_weights = asn_config.full_asnp_structure_weights
else:
# video input: (batch_size, FLAGS.num_frames, image_size, image_size, 3)
input_specs = tf.keras.layers.InputSpec(
shape=(batch_size, num_frames, img_size, img_size, 3))
inputs = np.random.rand(batch_size, num_frames, img_size, img_size, 3)
# Here, we are using model_structures.asn50_structure for AssembleNet++
# instead of full_asnp50_structure. By using asn50_structure, it
# essentially becomes AssembleNet++ without objects, only requiring RGB
# inputs (and optical flow to be computed inside the model).
model_structure = asn_config.asn50_structure
edge_weights = asn_config.asn_structure_weights
model = asnp.assemblenet_plus(
assemblenet_depth=depth,
num_classes=num_classes,
num_frames=num_frames,
model_structure=model_structure,
model_edge_weights=edge_weights,
input_specs=input_specs,
use_object_input=use_object_input,
attention_mode=attention_mode,
)
outputs = model(inputs)
self.assertAllEqual(outputs.shape.as_list(), [batch_size, num_classes])
if __name__ == '__main__':
tf.test.main()
@@ -13,7 +13,16 @@
# limitations under the License.
# Lint as: python3
"""Training driver."""
r"""Training driver.
Commandline:
python -m official.vision.beta.projects.assemblenet.train \
--mode=train_and_eval --experiment=assemblenetplus_ucf101 \
--model_dir='YOUR MODEL SAVE GS BUCKET' \
--config_file=./official/vision/beta/projects/assemblenet/ucf101_assemblenet_plus_tpu.yaml \
--tpu=TPU_NAME
"""
from absl import app
from absl import flags
@@ -32,6 +41,7 @@ from official.modeling import performance
# pylint: disable=unused-import
from official.vision.beta.projects.assemblenet.configs import assemblenet as asn_configs
from official.vision.beta.projects.assemblenet.modeling import assemblenet as asn
from official.vision.beta.projects.assemblenet.modeling import assemblenet_plus as asnp
# pylint: enable=unused-import
FLAGS = flags.FLAGS
@@ -53,18 +63,35 @@ def main(_):
f'{params.task.validation_data.feature_shape}')
if 'assemblenet' in FLAGS.experiment:
if 'eval' in FLAGS.mode:
# Use the feature shape in validation_data for all jobs. The number of
# frames in train_data will be used to construct the Assemblenet model.
params.task.model.backbone.assemblenet.num_frames = params.task.validation_data.feature_shape[
0]
shape = params.task.validation_data.feature_shape
if 'plus' in FLAGS.experiment:
if 'eval' in FLAGS.mode:
# Use the feature shape in validation_data for all jobs. The number of
# frames in train_data will be used to construct the Assemblenet++
# model.
params.task.model.backbone.assemblenet_plus.num_frames = (
params.task.validation_data.feature_shape[0])
shape = params.task.validation_data.feature_shape
else:
params.task.model.backbone.assemblenet_plus.num_frames = (
params.task.train_data.feature_shape[0])
shape = params.task.train_data.feature_shape
logging.info('mode %r num_frames %r feature shape %r', FLAGS.mode,
params.task.model.backbone.assemblenet_plus.num_frames,
shape)
else:
params.task.model.backbone.assemblenet.num_frames = params.task.train_data.feature_shape[
0]
shape = params.task.train_data.feature_shape
logging.info('mode %r num_frames %r feature shape %r', FLAGS.mode,
params.task.model.backbone.assemblenet.num_frames, shape)
if 'eval' in FLAGS.mode:
# Use the feature shape in validation_data for all jobs. The number of
# frames in train_data will be used to construct the Assemblenet model.
params.task.model.backbone.assemblenet.num_frames = (
params.task.validation_data.feature_shape[0])
shape = params.task.validation_data.feature_shape
else:
params.task.model.backbone.assemblenet.num_frames = (
params.task.train_data.feature_shape[0])
shape = params.task.train_data.feature_shape
logging.info('mode %r num_frames %r feature shape %r', FLAGS.mode,
params.task.model.backbone.assemblenet.num_frames, shape)
# Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
# can have significant impact on model speeds by utilizing float16 in case of
@@ -91,4 +118,5 @@ def main(_):
if __name__ == '__main__':
tfm_flags.define_flags()
flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
app.run(main)