Commit c609ff2e authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 369249071
parent 56cda9c5
# AssembleNet and AssembleNet++
This repository contains the official implementations of the following papers.
[![Paper](http://img.shields.io/badge/Paper-arXiv.1905.13209-B3181B?logo=arXiv)](https://arxiv.org/abs/1905.13209)
[AssembleNet: Searching for Multi-Stream Neural Connectivity in Video
Architectures](https://arxiv.org/abs/1905.13209)
[![Paper](http://img.shields.io/badge/Paper-arXiv.2008.08072-B3181B?logo=arXiv)](https://arxiv.org/abs/2008.08072)
[AssembleNet++: Assembling Modality Representations via Attention
Connections](https://arxiv.org/abs/2008.08072)
**DISCLAIMER**: The AssembleNet++ implementation is still under development.
No support will be provided during the development phase.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Definitions for AssembleNet/++ structures.
This structure is a `list` corresponding to a graph representation of the
network, where a node is a convolutional block and an edge specifies a
connection from one block to another.
Each node itself (in the structure list) is a list with the following format:
[block_level, [list_of_input_blocks], number_filter, temporal_dilation,
spatial_stride]. [list_of_input_blocks] should be the list of node indexes whose
values are less than the index of the node itself. The 'stems' of the network
directly taking raw inputs follow a different node format:
[stem_type, temporal_dilation]. The stem_type is -1 for RGB stem and is -2 for
optical flow stem. The stem_type -3 is reserved for the object segmentation
input.
In AssembleNet++lite, instead of passing a single `int` for number_filter, we
pass a list/tuple of three `int`s. They specify the number of channels to be
used for each layer in the inverted bottleneck modules.
The structure_weights specify the learned connection weights.
"""
from typing import List, Tuple
import dataclasses
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import hyperparams
from official.vision.beta.configs import backbones_3d
from official.vision.beta.configs import common
from official.vision.beta.configs.google import video_classification
@dataclasses.dataclass
class BlockSpec(hyperparams.Config):
level: int = -1
input_blocks: Tuple[int, ...] = tuple()
num_filters: int = -1
temporal_dilation: int = 1
spatial_stride: int = 1
input_block_weight: Tuple[float, ...] = tuple()
def flat_lists_to_blocks(model_structures, model_edge_weights):
"""Transforms the raw list structure configs to BlockSpec tuple."""
blocks = []
for node, edge_weights in zip(model_structures, model_edge_weights):
if node[0] < 0:
block = BlockSpec(level=node[0], temporal_dilation=node[1])
else:
block = BlockSpec(
level=node[0],
input_blocks=node[1],
num_filters=node[2],
temporal_dilation=node[3],
spatial_stride=node[4])
if edge_weights:
assert len(edge_weights[0]) == len(block.input_blocks), (
f'{len(edge_weights[0])} != {len(block.input_blocks)} at block '
f'{block} weight {edge_weights}')
block.input_block_weight = tuple(edge_weights[0])
blocks.append(block)
return tuple(blocks)
def blocks_to_flat_lists(blocks: List[BlockSpec]):
"""Transforms BlockSpec tuple to the raw list structure configs."""
# pylint: disable=g-complex-comprehension
# pylint: disable=g-long-ternary
model_structure = [[
b.level,
list(b.input_blocks), b.num_filters, b.temporal_dilation,
b.spatial_stride, 0
] if b.level >= 0 else [b.level, b.temporal_dilation] for b in blocks]
model_edge_weights = [
[list(b.input_block_weight)] if b.input_block_weight else []
for b in blocks
]
return model_structure, model_edge_weights
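# Example (illustrative only; the values below are hypothetical, not one of the
# released structures): a tiny two-stem graph round-tripped through the
# converters above.
#
#   toy_structure = [[-1, 1], [-2, 1], [0, [0, 1], 32, 1, 1, 0]]
#   toy_weights = [[], [], [[0.5, 0.5]]]
#   blocks = flat_lists_to_blocks(toy_structure, toy_weights)
#   structure, weights = blocks_to_flat_lists(blocks)
#   # structure == toy_structure and weights == toy_weights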
# AssembleNet structure for 50/101 layer models, found using evolution with the
# Moments-in-Time dataset. This is the structure used for the experiments in the
# AssembleNet paper. The learned connectivity weights are also provided.
asn50_structure = [[-1, 4], [-1, 4], [-2, 1], [-2, 1], [0, [1], 32, 1, 1, 0],
[0, [0], 32, 4, 1, 0], [0, [0, 1, 2, 3], 32, 1, 1, 0],
[0, [2, 3], 32, 2, 1, 0], [1, [0, 4, 5, 6, 7], 64, 2, 2, 0],
[1, [0, 2, 4, 7], 64, 1, 2, 0], [1, [0, 5, 7], 64, 4, 2, 0],
[1, [0, 5], 64, 1, 2, 0], [2, [4, 8, 10, 11], 256, 1, 2, 0],
[2, [8, 9], 256, 4, 2, 0], [3, [12, 13], 512, 2, 2, 0]]
asn101_structure = [[-1, 4], [-1, 4], [-2, 1], [-2, 1], [0, [1], 32, 1, 1, 0],
[0, [0], 32, 4, 1, 0], [0, [0, 1, 2, 3], 32, 1, 1, 0],
[0, [2, 3], 32, 2, 1, 0], [1, [0, 4, 5, 6, 7], 64, 2, 2, 0],
[1, [0, 2, 4, 7], 64, 1, 2, 0], [1, [0, 5, 7], 64, 4, 2, 0],
[1, [0, 5], 64, 1, 2, 0], [2, [4, 8, 10, 11], 192, 1, 2, 0],
[2, [8, 9], 192, 4, 2, 0], [3, [12, 13], 256, 2, 2, 0]]
asn_structure_weights = [
[], [], [], [], [], [],
[[
0.13810564577579498, 0.8465337157249451, 0.3072969317436218,
0.2867436408996582
]], [[0.5846117734909058, 0.6066334843635559]],
[[
0.16382087767124176, 0.8852924704551697, 0.4039595425128937,
0.6823437809944153, 0.5331538319587708
]],
[[
0.028569204732775688, 0.10333596915006638, 0.7517264485359192,
0.9260114431381226
]], [[0.28832191228866577, 0.7627848982810974, 0.404977947473526]],
[[0.23474831879138947, 0.7841425538063049]],
[[
0.27616503834724426, 0.9514784812927246, 0.6568767428398132,
0.9547983407974243
]], [[0.5047007203102112, 0.8876819610595703]],
[[0.9892204403877258, 0.8454614877700806]]
]
# AssembleNet++ structure for 50 layer models, found with the Charades dataset.
# This is the model used in the experiments in the AssembleNet++ paper.
# Note that, in order to build AssembleNet++ with this structure, you also need
# to feed the 'object segmentation input' to the network, indicated as [-3, 4].
# It is the 5th block in the architecture.
# If you don't plan to use the object input but still want to benefit from
# peer-attention in AssembleNet++ (with RGB and OF), please use the above
# AssembleNet-50 structure instead with the assemblenet_plus.py code.
full_asnp50_structure = [[-1, 2], [-1, 4], [-2, 2], [-2, 1], [-3, 4],
[0, [0, 1, 2, 3, 4], 32, 1, 1, 0],
[0, [0, 1, 4], 32, 4, 1, 0],
[0, [2, 3, 4], 32, 8, 1, 0],
[0, [2, 3, 4], 32, 1, 1, 0],
[1, [0, 1, 2, 4, 5, 6, 7, 8], 64, 4, 2, 0],
[1, [2, 3, 4, 7, 8], 64, 1, 2, 0],
[1, [0, 4, 5, 6, 7], 128, 8, 2, 0],
[2, [4, 11], 256, 8, 2, 0],
[2, [2, 3, 4, 5, 6, 7, 8, 10, 11], 256, 4, 2, 0],
[3, [12, 13], 512, 2, 2, 0]]
full_asnp_structure_weights = [[], [], [], [], [], [[0.6143830418586731, 0.7111759185791016, 0.19351491332054138, 0.1701001077890396, 0.7178536653518677]], [[0.5755624771118164, 0.5644599795341492, 0.7128658294677734]], [[0.26563042402267456, 0.3033692538738251, 0.8244096636772156]], [[0.07013848423957825, 0.07905343919992447, 0.8767927885055542]], [[0.5008697509765625, 0.5020178556442261, 0.49819135665893555, 0.5015180706977844, 0.4987695813179016, 0.4990265369415283, 0.499239057302475, 0.4974501430988312]], [[0.47034338116645813, 0.4694305658340454, 0.767791748046875, 0.5539310574531555, 0.4520096182823181]], [[0.2769702076911926, 0.8116549253463745, 0.597356915473938, 0.6585626602172852, 0.5915306210517883]], [[0.501274824142456, 0.5016682147979736]], [[0.0866393893957138, 0.08469288796186447, 0.9739039540290833, 0.058271341025829315, 0.08397126197814941, 0.10285478830337524, 0.18506969511508942, 0.23874442279338837, 0.9188644886016846]], [[0.4174623489379883, 0.5844835638999939]]] # pylint: disable=line-too-long
# AssembleNet++lite structure using inverted bottleneck blocks. By specifying
# the connection weights as [], the model can also learn the connection
# weights automatically during training.
asnp_lite_structure = [[-1, 1], [-2, 1],
[0, [0, 1], [27, 27, 12], 1, 2, 0],
[0, [0, 1], [27, 27, 12], 4, 2, 0],
[1, [0, 1, 2, 3], [54, 54, 24], 2, 2, 0],
[1, [0, 1, 2, 3], [54, 54, 24], 1, 2, 0],
[1, [0, 1, 2, 3], [54, 54, 24], 4, 2, 0],
[1, [0, 1, 2, 3], [54, 54, 24], 1, 2, 0],
[2, [0, 1, 2, 3, 4, 5, 6, 7], [152, 152, 68], 1, 2, 0],
[2, [0, 1, 2, 3, 4, 5, 6, 7], [152, 152, 68], 4, 2, 0],
[3, [2, 3, 4, 5, 6, 7, 8, 9], [432, 432, 192], 2, 2, 0]]
asnp_lite_structure_weights = [[], [], [[0.19914183020591736, 0.9278576374053955]], [[0.010816320776939392, 0.888792097568512]], [[0.9473835825920105, 0.6303419470787048, 0.1704932451248169, 0.05950307101011276]], [[0.9560931324958801, 0.7898273468017578, 0.36138781905174255, 0.07344610244035721]], [[0.9213919043540955, 0.13418640196323395, 0.8371981978416443, 0.07936054468154907]], [[0.9441559910774231, 0.9435100555419922, 0.7253988981246948, 0.13498817384243011]], [[0.9964852333068848, 0.8427878618240356, 0.8895476460456848, 0.11014710366725922, 0.6270533204078674, 0.44782018661499023, 0.61344975233078, 0.44898226857185364]], [[0.9970942735671997, 0.7105681896209717, 0.5078442096710205, 0.0951600968837738, 0.624282717704773, 0.8527252674102783, 0.8105692863464355, 0.7857823967933655]], [[0.6180334091186523, 0.11882413923740387, 0.06102970987558365, 0.04484326392412186, 0.05602221190929413, 0.052324872463941574, 0.9969874024391174, 0.9987731575965881]]] # pylint: disable=line-too-long
@dataclasses.dataclass
class AssembleNet(hyperparams.Config):
model_id: str = '50'
num_frames: int = 0
combine_method: str = 'sigmoid'
blocks: Tuple[BlockSpec, ...] = tuple()
@dataclasses.dataclass
class Backbone3D(backbones_3d.Backbone3D):
"""Configuration for backbones.
Attributes:
type: 'str', type of backbone to be used, one of the fields below.
assemblenet: assemblenet backbone config.
"""
type: str = 'assemblenet'
assemblenet: AssembleNet = AssembleNet()
@dataclasses.dataclass
class AssembleNetModel(video_classification.VideoClassificationModel):
"""The AssembleNet model config."""
model_type: str = 'assemblenet'
backbone: Backbone3D = Backbone3D()
norm_activation: common.NormActivation = common.NormActivation(
norm_momentum=0.99, norm_epsilon=1e-5, use_sync_bn=True)
max_pool_predictions: bool = False
@exp_factory.register_config_factory('assemblenet50_kinetics600')
def assemblenet_kinetics600() -> cfg.ExperimentConfig:
"""Video classification on Videonet with assemblenet."""
exp = video_classification.video_classification_kinetics600()
feature_shape = (32, 224, 224, 3)
exp.task.train_data.global_batch_size = 1024
exp.task.validation_data.global_batch_size = 32
exp.task.train_data.feature_shape = feature_shape
exp.task.validation_data.feature_shape = (120, 224, 224, 3)
exp.task.train_data.dtype = 'bfloat16'
exp.task.validation_data.dtype = 'bfloat16'
model = AssembleNetModel()
model.backbone.assemblenet.model_id = '50'
model.backbone.assemblenet.blocks = flat_lists_to_blocks(
asn50_structure, asn_structure_weights)
model.backbone.assemblenet.num_frames = feature_shape[0]
exp.task.model = model
assert exp.task.model.backbone.assemblenet.num_frames > 0, (
f'backbone num_frames '
f'{exp.task.model.backbone.assemblenet}')
return exp
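# Example usage (illustrative): once this module is imported, the experiment
# registered above can be retrieved from the factory by name via
# `official.core.exp_factory.get_exp_config`.
#
#   exp = exp_factory.get_exp_config('assemblenet50_kinetics600')
#   print(exp.task.model.backbone.assemblenet.num_frames)  # 32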
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Contains definitions for 'Representation Flow' layer [1].
Representation flow layer is a generalization of optical flow extraction; the
layer could be inserted anywhere within a CNN to capture feature movements. This
is the version taking 4D tensor with the shape [batch*time, height, width,
channels], to make this run on TPU.
[1] AJ Piergiovanni and Michael S. Ryoo,
Representation Flow for Action Recognition. CVPR 2019.
"""
import numpy as np
import tensorflow as tf
layers = tf.keras.layers
BATCH_NORM_DECAY = 0.99
BATCH_NORM_EPSILON = 1e-5
def build_batch_norm(init_zero: bool = False,
bn_decay: float = BATCH_NORM_DECAY,
bn_epsilon: float = BATCH_NORM_EPSILON,
use_sync_bn: bool = False):
"""Performs a batch normalization followed by a ReLU.
Args:
init_zero: `bool` if True, initializes scale parameter of batch
normalization with 0 instead of 1 (default).
bn_decay: `float` batch norm decay parameter to use.
bn_epsilon: `float` batch norm epsilon parameter to use.
use_sync_bn: `bool` whether to use synchronized batch norm for TPU.
Returns:
A `tf.keras.layers.BatchNormalization` (or `SyncBatchNormalization`) layer.
"""
if init_zero:
gamma_initializer = tf.zeros_initializer()
else:
gamma_initializer = tf.ones_initializer()
data_format = tf.keras.backend.image_data_format()
assert data_format == 'channels_last'
if data_format == 'channels_first':
axis = 1
else:
axis = -1
if use_sync_bn:
batch_norm = layers.experimental.SyncBatchNormalization(
axis=axis,
momentum=bn_decay,
epsilon=bn_epsilon,
gamma_initializer=gamma_initializer)
else:
batch_norm = layers.BatchNormalization(
axis=axis,
momentum=bn_decay,
epsilon=bn_epsilon,
fused=True,
gamma_initializer=gamma_initializer)
return batch_norm
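# Example usage (illustrative): the helper returns a Keras layer object, which
# is then applied to a feature tensor.
#
#   bn = build_batch_norm(init_zero=False, use_sync_bn=False)
#   y = bn(x, training=True)  # `x` is a [batch, height, width, channels] tensor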
def divergence(p1, p2, f_grad_x, f_grad_y, name):
"""Computes the divergence value used with TV-L1 optical flow algorithm.
Args:
p1: 'Tensor' input.
p2: 'Tensor' input in the next frame.
f_grad_x: 'Tensor' x gradient of F value used in TV-L1.
f_grad_y: 'Tensor' y gradient of F value used in TV-L1.
name: 'str' name for the variable scope.
Returns:
A `Tensor` with the same `data_format` and shape as input.
"""
data_format = tf.keras.backend.image_data_format()
df = 'NHWC' if data_format == 'channels_last' else 'NCHW'
with tf.name_scope('divergence_' + name):
if data_format == 'channels_last':
p1 = tf.pad(p1[:, :, :-1, :], [[0, 0], [0, 0], [1, 0], [0, 0]])
p2 = tf.pad(p2[:, :-1, :, :], [[0, 0], [1, 0], [0, 0], [0, 0]])
else:
p1 = tf.pad(p1[:, :, :, :-1], [[0, 0], [0, 0], [0, 0], [1, 0]])
p2 = tf.pad(p2[:, :, :-1, :], [[0, 0], [0, 0], [1, 0], [0, 0]])
grad_x = tf.nn.conv2d(p1, f_grad_x, [1, 1, 1, 1], 'SAME', data_format=df)
grad_y = tf.nn.conv2d(p2, f_grad_y, [1, 1, 1, 1], 'SAME', data_format=df)
return grad_x + grad_y
def forward_grad(x, f_grad_x, f_grad_y, name):
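"""Computes the forward gradients of a flow field, used in TV-L1.
Args:
x: 'Tensor' input flow component.
f_grad_x: 'Tensor' x gradient filter.
f_grad_y: 'Tensor' y gradient filter.
name: 'str' name for the name scope.
Returns:
A tuple of two `Tensor`s (x gradient, y gradient) with the same shape as x.
"""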
data_format = tf.keras.backend.image_data_format()
with tf.name_scope('forward_grad_' + name):
df = 'NHWC' if data_format == 'channels_last' else 'NCHW'
grad_x = tf.nn.conv2d(x, f_grad_x, [1, 1, 1, 1], 'SAME', data_format=df)
grad_y = tf.nn.conv2d(x, f_grad_y, [1, 1, 1, 1], 'SAME', data_format=df)
return grad_x, grad_y
def norm_img(x):
mx = tf.reduce_max(x)
mn = tf.reduce_min(x)
if mx == mn:
return x
else:
return 255 * (x - mn) / (mx - mn)
class RepresentationFlow(layers.Layer):
"""Computes the representation flow motivated by TV-L1 optical flow."""
def __init__(self,
time: int,
depth: int,
num_iter: int = 20,
bottleneck: int = 32,
train_feature_grad: bool = False,
train_divergence: bool = False,
train_flow_grad: bool = False,
train_hyper: bool = False,
**kwargs):
"""Constructor.
Args:
time: 'int' number of frames in the input tensor.
depth: channel depth of the input tensor.
num_iter: 'int' number of iterations to use for the flow computation.
bottleneck: 'int' number of filters to be used for the flow computation.
train_feature_grad: 'bool' whether to train the image gradient filter params.
train_divergence: 'bool' whether to train the divergence filter params.
train_flow_grad: 'bool' whether to train the flow gradient filter params.
train_hyper: 'bool' whether to train the representation flow hyperparameters
(theta, lambda, tau).
**kwargs: keyword arguments to be passed to the parent constructor.
"""
super(RepresentationFlow, self).__init__(**kwargs)
self._time = time
self._depth = depth
self._num_iter = num_iter
self._bottleneck = bottleneck
self._train_feature_grad = train_feature_grad
self._train_divergence = train_divergence
self._train_flow_grad = train_flow_grad
self._train_hyper = train_hyper
def get_config(self):
config = {
'time': self._time,
'depth': self._depth,
'num_iter': self._num_iter,
'bottleneck': self._bottleneck,
'train_feature_grad': self._train_feature_grad,
'train_divergence': self._train_divergence,
'train_flow_grad': self._train_flow_grad,
'train_hyper': self._train_hyper,
}
base_config = super(RepresentationFlow, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def build(self, input_shape: tf.TensorShape):
img_grad = np.array([-0.5, 0, 0.5], dtype='float32')
img_grad_x = np.repeat(
np.reshape(img_grad, (1, 3, 1, 1)), self._bottleneck, axis=2) * np.eye(
self._bottleneck, dtype='float32')
self.img_grad_x = self.add_weight(
shape=img_grad_x.shape,
initializer=tf.constant_initializer(img_grad_x),
trainable=self._train_feature_grad,
name='img_grad_x')
img_grad_y = np.repeat(
np.reshape(img_grad, (3, 1, 1, 1)), self._bottleneck, axis=2) * np.eye(
self._bottleneck, dtype='float32')
self.img_grad_y = self.add_weight(
shape=img_grad_y.shape,
initializer=tf.constant_initializer(img_grad_y),
trainable=self._train_feature_grad,
name='img_grad_y')
f_grad = np.array([-1, 1], dtype='float32')
f_grad_x = np.repeat(
np.reshape(f_grad, (1, 2, 1, 1)), self._bottleneck, axis=2) * np.eye(
self._bottleneck, dtype='float32')
self.f_grad_x = self.add_weight(
shape=f_grad_x.shape,
initializer=tf.constant_initializer(f_grad_x),
trainable=self._train_divergence,
name='f_grad_x')
f_grad_y = np.repeat(
np.reshape(f_grad, (2, 1, 1, 1)), self._bottleneck, axis=2) * np.eye(
self._bottleneck, dtype='float32')
self.f_grad_y = self.add_weight(
shape=f_grad_y.shape,
initializer=tf.constant_initializer(f_grad_y),
trainable=self._train_divergence,
name='f_grad_y')
f_grad_x2 = np.repeat(
np.reshape(f_grad, (1, 2, 1, 1)), self._bottleneck, axis=2) * np.eye(
self._bottleneck, dtype='float32')
self.f_grad_x2 = self.add_weight(
shape=f_grad_x2.shape,
initializer=tf.constant_initializer(f_grad_x2),
trainable=self._train_flow_grad,
name='f_grad_x2')
f_grad_y2 = np.repeat(
np.reshape(f_grad, (2, 1, 1, 1)), self._bottleneck, axis=2) * np.eye(
self._bottleneck, dtype='float32')
self.f_grad_y2 = self.add_weight(
shape=f_grad_y2.shape,
initializer=tf.constant_initializer(f_grad_y2),
trainable=self._train_flow_grad,
name='f_grad_y2')
self.t = self.add_weight(
name='theta',
initializer=tf.constant_initializer(0.3),
trainable=self._train_hyper)
self.l = self.add_weight(
name='lambda',
initializer=tf.constant_initializer(0.15),
trainable=self._train_hyper)
self.a = self.add_weight(
name='tau',
initializer=tf.constant_initializer(0.25),
trainable=self._train_hyper)
self.t = tf.abs(self.t) + 1e-12
self.l_t = self.l * self.t
self.taut = self.a / self.t
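# theta, lambda and tau are the TV-L1 hyperparameters: l_t = lambda * theta
# scales the data-term thresholding and taut = tau / theta is the step size of
# the dual-variable update in call().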
self._bottleneck_conv1 = None
self._bottleneck_conv2 = None
if self._bottleneck > 1:
self._bottleneck_conv1 = layers.Conv2D(
filters=self._bottleneck,
kernel_size=1,
strides=1,
padding='same',
use_bias=False,
kernel_initializer=tf.keras.initializers.VarianceScaling(),
name='rf/bottleneck1')
self._bottleneck_conv2 = layers.Conv2D(
filters=self._depth,
kernel_size=1,
strides=1,
padding='same',
use_bias=False,
kernel_initializer=tf.keras.initializers.VarianceScaling(),
name='rf/bottleneck2')
self._batch_norm = build_batch_norm(init_zero=True)
def call(self, inputs: tf.Tensor, training: bool = None) -> tf.Tensor:
"""Perform representation flows.
Args:
inputs: a `Tensor` of shape `[batch*time, height, width, channels]`.
training: True for training phase.
Returns:
A tensor of the same shape as the inputs.
"""
data_format = tf.keras.backend.image_data_format()
df = 'NHWC' if data_format == 'channels_last' else 'NCHW'
axis = 3 if data_format == 'channels_last' else 1 # channel axis
dtype = inputs.dtype
residual = inputs
depth = inputs.shape.as_list()[axis]
# assert depth == self._depth, f'rep_flow {depth} != {self._depth}'
if self._bottleneck == 1:
inputs = tf.reduce_mean(inputs, axis=axis)
inputs = tf.expand_dims(inputs, -1)
elif depth != self._bottleneck:
inputs = self._bottleneck_conv1(inputs)
input_shape = inputs.shape.as_list()
inp = norm_img(inputs)
inp = tf.reshape(
inp,
(-1, self._time, inputs.shape[1], inputs.shape[2], inputs.shape[3]))
inp = tf.ensure_shape(
inp, (None, self._time, input_shape[1], input_shape[2], input_shape[3]))
img1 = tf.reshape(
inp[:, :-1], (-1, tf.shape(inp)[2], tf.shape(inp)[3], tf.shape(inp)[4]))
img2 = tf.reshape(
inp[:, 1:], (-1, tf.shape(inp)[2], tf.shape(inp)[3], tf.shape(inp)[4]))
img1 = tf.ensure_shape(
img1, (None, inputs.shape[1], inputs.shape[2], inputs.shape[3]))
img2 = tf.ensure_shape(
img2, (None, inputs.shape[1], inputs.shape[2], inputs.shape[3]))
u1 = tf.zeros_like(img1, dtype=dtype)
u2 = tf.zeros_like(img2, dtype=dtype)
l_t = self.l_t
taut = self.taut
grad2_x = tf.nn.conv2d(
img2, self.img_grad_x, [1, 1, 1, 1], 'SAME', data_format=df)
grad2_y = tf.nn.conv2d(
img2, self.img_grad_y, [1, 1, 1, 1], 'SAME', data_format=df)
p11 = tf.zeros_like(img1, dtype=dtype)
p12 = tf.zeros_like(img1, dtype=dtype)
p21 = tf.zeros_like(img1, dtype=dtype)
p22 = tf.zeros_like(img1, dtype=dtype)
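# u1, u2 hold the current flow estimate (the primal variables); p11, p12, p21
# and p22 are the dual variables of the total-variation regularizer. All are
# initialized to zero before the iterative updates below.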
gsqx = grad2_x**2
gsqy = grad2_y**2
grad = gsqx + gsqy + 1e-12
rho_c = img2 - grad2_x * u1 - grad2_y * u2 - img1
for _ in range(self._num_iter):
rho = rho_c + grad2_x * u1 + grad2_y * u2 + 1e-12
v1 = tf.zeros_like(img1, dtype=dtype)
v2 = tf.zeros_like(img2, dtype=dtype)
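# Thresholding step on the data term: depending on how rho compares to
# +/- l_t * |grad|^2, the flow update is +/- l_t * grad2 or the projection
# -(rho / |grad|^2) * grad2.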
mask1 = rho < -l_t * grad
tmp11 = tf.where(mask1, l_t * grad2_x,
tf.zeros_like(grad2_x, dtype=dtype))
tmp12 = tf.where(mask1, l_t * grad2_y,
tf.zeros_like(grad2_y, dtype=dtype))
mask2 = rho > l_t * grad
tmp21 = tf.where(mask2, -l_t * grad2_x,
tf.zeros_like(grad2_x, dtype=dtype))
tmp22 = tf.where(mask2, -l_t * grad2_y,
tf.zeros_like(grad2_y, dtype=dtype))
mask3 = (~mask1) & (~mask2) & (grad > 1e-12)
tmp31 = tf.where(mask3, (-rho / grad) * grad2_x,
tf.zeros_like(grad2_x, dtype=dtype))
tmp32 = tf.where(mask3, (-rho / grad) * grad2_y,
tf.zeros_like(grad2_y, dtype=dtype))
v1 = tmp11 + tmp21 + tmp31 + u1
v2 = tmp12 + tmp22 + tmp32 + u2
u1 = v1 + self.t * divergence(p11, p12, self.f_grad_x, self.f_grad_y,
'div_p1')
u2 = v2 + self.t * divergence(p21, p22, self.f_grad_x, self.f_grad_y,
'div_p2')
u1x, u1y = forward_grad(u1, self.f_grad_x2, self.f_grad_y2, 'u1')
u2x, u2y = forward_grad(u2, self.f_grad_x2, self.f_grad_y2, 'u2')
p11 = (p11 + taut * u1x) / (1. + taut * tf.sqrt(u1x**2 + u1y**2 + 1e-12))
p12 = (p12 + taut * u1y) / (1. + taut * tf.sqrt(u1x**2 + u1y**2 + 1e-12))
p21 = (p21 + taut * u2x) / (1. + taut * tf.sqrt(u2x**2 + u2y**2 + 1e-12))
p22 = (p22 + taut * u2y) / (1. + taut * tf.sqrt(u2x**2 + u2y**2 + 1e-12))
u1 = tf.reshape(u1, (-1, self._time - 1, tf.shape(u1)[1],
tf.shape(u1)[2], tf.shape(u1)[3]))
u2 = tf.reshape(u2, (-1, self._time - 1, tf.shape(u2)[1],
tf.shape(u2)[2], tf.shape(u2)[3]))
flow = tf.concat([u1, u2], axis=axis + 1)
flow = tf.concat([
flow,
tf.reshape(
flow[:, -1, :, :, :],
(-1, 1, tf.shape(u1)[2], tf.shape(u1)[3], tf.shape(u1)[4] * 2))
],
axis=1)
# Pad the time dimension by repeating the last frame's flow:
# [bs, t-1, h, w, 2*c] -> [bs, t, h, w, 2*c].
flow = tf.reshape(
flow, (-1, tf.shape(u1)[2], tf.shape(u2)[3], tf.shape(u1)[4] * 2))
# flow is now [bs*t, h, w, 2*c].
if self._bottleneck == 1:
output_shape = residual.shape.as_list()
output_shape[-1] = self._bottleneck * 2
flow = tf.ensure_shape(flow, output_shape)
return flow
else:
flow = self._bottleneck_conv2(flow)
flow = self._batch_norm(flow)
flow = tf.ensure_shape(flow, residual.shape)
return tf.nn.relu(flow + residual)
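# Example usage (illustrative; the shapes and values below are hypothetical):
#
#   frames = tf.random.normal([2 * 8, 56, 56, 64])  # [batch*time, h, w, c]
#   rep_flow = RepresentationFlow(time=8, depth=64, num_iter=10, bottleneck=32)
#   features = rep_flow(frames, training=False)  # same shape as `frames`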
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Training driver."""
from absl import app
from absl import flags
from absl import logging
import gin
# pylint: disable=unused-import
from official.common import registry_imports
# pylint: enable=unused-import
from official.common import distribute_utils
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
# pylint: disable=unused-import
from official.vision.beta.projects.assemblenet.configs import assemblenet as asn_configs
from official.vision.beta.projects.assemblenet.modeling import assemblenet as asn
# pylint: enable=unused-import
FLAGS = flags.FLAGS
def main(_):
gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
params = train_utils.parse_configuration(FLAGS)
model_dir = FLAGS.model_dir
if 'train' in FLAGS.mode:
# Pure eval modes do not output yaml files. Otherwise, a continuous eval job
# may race against the train job when writing the same file.
train_utils.serialize_config(params, model_dir)
if 'train_and_eval' in FLAGS.mode:
assert (params.task.train_data.feature_shape ==
params.task.validation_data.feature_shape), (
f'train {params.task.train_data.feature_shape} != validate '
f'{params.task.validation_data.feature_shape}')
if 'assemblenet' in FLAGS.experiment:
if 'eval' in FLAGS.mode:
# For eval jobs, use the feature shape in validation_data; otherwise the
# number of frames in train_data is used to construct the AssembleNet model.
params.task.model.backbone.assemblenet.num_frames = params.task.validation_data.feature_shape[
0]
shape = params.task.validation_data.feature_shape
else:
params.task.model.backbone.assemblenet.num_frames = params.task.train_data.feature_shape[
0]
shape = params.task.train_data.feature_shape
logging.info('mode %r num_frames %r feature shape %r', FLAGS.mode,
params.task.model.backbone.assemblenet.num_frames, shape)
# Sets the mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
# can have a significant impact on model speed by utilizing float16 on GPUs
# and bfloat16 on TPUs. loss_scale takes effect only when the dtype is
# float16.
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
num_gpus=params.runtime.num_gpus,
tpu_address=params.runtime.tpu)
with distribution_strategy.scope():
task = task_factory.get_task(params.task, logging_dir=model_dir)
train_lib.run_experiment(
distribution_strategy=distribution_strategy,
task=task,
mode=FLAGS.mode,
params=params,
model_dir=model_dir)
train_utils.save_gin_config(FLAGS.mode, model_dir)
if __name__ == '__main__':
tfm_flags.define_flags()
app.run(main)
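# Example launch (illustrative; the flags are defined by tfm_flags.define_flags):
#
#   python3 train.py \
#     --experiment=assemblenet50_kinetics600 \
#     --mode=train_and_eval \
#     --model_dir=/tmp/assemblenet50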