ModelZoo / ResNet50_tensorflow · Commits · c609ff2e

Commit c609ff2e
Authored Apr 19, 2021 by Yeqing Li; committed by A. Unique TensorFlower on Apr 19, 2021

Internal change

PiperOrigin-RevId: 369249071

Parent: 56cda9c5

Showing 5 changed files with 1811 additions and 0 deletions (+1811, -0)
official/vision/beta/projects/assemblenet/README.md (+14, -0)
official/vision/beta/projects/assemblenet/configs/assemblenet.py (+225, -0)
official/vision/beta/projects/assemblenet/modeling/assemblenet.py (+1073, -0)
official/vision/beta/projects/assemblenet/modeling/rep_flow_2d_layer.py (+405, -0)
official/vision/beta/projects/assemblenet/train.py (+94, -0)
official/vision/beta/projects/assemblenet/README.md (new file, mode 100644)
# AssembleNet and AssembleNet++

This repository contains the official implementations of the following papers.

[AssembleNet: Searching for Multi-Stream Neural Connectivity in Video
Architectures](https://arxiv.org/abs/1905.13209)

[AssembleNet++: Assembling Modality Representations via Attention
Connections](https://arxiv.org/abs/2008.08072)

**DISCLAIMER**: The AssembleNet++ implementation is still under development.
No support will be provided during the development phase.
official/vision/beta/projects/assemblenet/configs/assemblenet.py (new file, mode 100644)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Definitions for AssembleNet/++ structures.
This structure is a `list` corresponding to a graph representation of the
network, where a node is a convolutional block and an edge specifies a
connection from one block to another.
Each node itself (in the structure list) is a list with the following format:
[block_level, [list_of_input_blocks], number_filter, temporal_dilation,
spatial_stride]. [list_of_input_blocks] should be the list of node indexes whose
values are less than the index of the node itself. The 'stems' of the network
directly taking raw inputs follow a different node format:
[stem_type, temporal_dilation]. The stem_type is -1 for RGB stem and is -2 for
optical flow stem. The stem_type -3 is reserved for the object segmentation
input.
In AssembleNet++lite, instead of passing a single `int` for number_filter, we
pass a list/tuple of three `int`s. They specify the number of channels to be
used for each layer in the inverted bottleneck modules.
The structure_weights specify the learned connection weights.
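
For illustration (editor's sketch, not part of the original docstring), a
minimal two-stem structure and its connection weights could look like:

  example_structure = [
      [-1, 4],                # RGB stem, temporal dilation 4
      [-2, 1],                # optical flow stem, temporal dilation 1
      [0, [0, 1], 32, 1, 1],  # level-0 block fed by both stems
      [3, [2], 512, 2, 2],    # level-3 block fed by node 2
  ]
  example_weights = [[], [], [[0.5, 0.5]], [[1.0]]]
  example_blocks = flat_lists_to_blocks(example_structure, example_weights)

where the `example_*` names and values are made up for illustration only.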
"""
from typing import List, Tuple

import dataclasses
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import hyperparams
from official.vision.beta.configs import backbones_3d
from official.vision.beta.configs import common
from official.vision.beta.configs.google import video_classification


@dataclasses.dataclass
class BlockSpec(hyperparams.Config):
  level: int = -1
  input_blocks: Tuple[int, ...] = tuple()
  num_filters: int = -1
  temporal_dilation: int = 1
  spatial_stride: int = 1
  input_block_weight: Tuple[float, ...] = tuple()


def flat_lists_to_blocks(model_structures, model_edge_weights):
  """Transforms the raw list structure configs to BlockSpec tuple."""
  blocks = []
  for node, edge_weights in zip(model_structures, model_edge_weights):
    if node[0] < 0:
      block = BlockSpec(level=node[0], temporal_dilation=node[1])
    else:
      block = BlockSpec(
          level=node[0],
          input_blocks=node[1],
          num_filters=node[2],
          temporal_dilation=node[3],
          spatial_stride=node[4])
      if edge_weights:
        assert len(edge_weights[0]) == len(block.input_blocks), (
            f'{len(edge_weights[0])} != {len(block.input_blocks)} at block '
            f'{block} weight {edge_weights}')
        block.input_block_weight = tuple(edge_weights[0])
    blocks.append(block)
  return tuple(blocks)


def blocks_to_flat_lists(blocks: List[BlockSpec]):
  """Transforms BlockSpec tuple to the raw list structure configs."""
  # pylint: disable=g-complex-comprehension
  # pylint: disable=g-long-ternary
  model_structure = [[
      b.level,
      list(b.input_blocks), b.num_filters, b.temporal_dilation,
      b.spatial_stride, 0
  ] if b.level >= 0 else [b.level, b.temporal_dilation] for b in blocks]
  model_edge_weights = [
      [list(b.input_block_weight)] if b.input_block_weight else []
      for b in blocks
  ]
  return model_structure, model_edge_weights

# AssembleNet structure for 50/101 layer models, found using evolution with the
# Moments-in-Time dataset. This is the structure used for the experiments in
# the AssembleNet paper. The learned connectivity weights are also provided.
asn50_structure = [[-1, 4], [-1, 4], [-2, 1], [-2, 1], [0, [1], 32, 1, 1, 0],
                   [0, [0], 32, 4, 1, 0], [0, [0, 1, 2, 3], 32, 1, 1, 0],
                   [0, [2, 3], 32, 2, 1, 0], [1, [0, 4, 5, 6, 7], 64, 2, 2, 0],
                   [1, [0, 2, 4, 7], 64, 1, 2, 0], [1, [0, 5, 7], 64, 4, 2, 0],
                   [1, [0, 5], 64, 1, 2, 0], [2, [4, 8, 10, 11], 256, 1, 2, 0],
                   [2, [8, 9], 256, 4, 2, 0], [3, [12, 13], 512, 2, 2, 0]]
asn101_structure = [[-1, 4], [-1, 4], [-2, 1], [-2, 1], [0, [1], 32, 1, 1, 0],
                    [0, [0], 32, 4, 1, 0], [0, [0, 1, 2, 3], 32, 1, 1, 0],
                    [0, [2, 3], 32, 2, 1, 0], [1, [0, 4, 5, 6, 7], 64, 2, 2, 0],
                    [1, [0, 2, 4, 7], 64, 1, 2, 0], [1, [0, 5, 7], 64, 4, 2, 0],
                    [1, [0, 5], 64, 1, 2, 0], [2, [4, 8, 10, 11], 192, 1, 2, 0],
                    [2, [8, 9], 192, 4, 2, 0], [3, [12, 13], 256, 2, 2, 0]]
asn_structure_weights = [
    [], [], [], [], [], [],
    [[
        0.13810564577579498, 0.8465337157249451, 0.3072969317436218,
        0.2867436408996582
    ]], [[0.5846117734909058, 0.6066334843635559]],
    [[
        0.16382087767124176, 0.8852924704551697, 0.4039595425128937,
        0.6823437809944153, 0.5331538319587708
    ]],
    [[
        0.028569204732775688, 0.10333596915006638, 0.7517264485359192,
        0.9260114431381226
    ]], [[0.28832191228866577, 0.7627848982810974, 0.404977947473526]],
    [[0.23474831879138947, 0.7841425538063049]],
    [[
        0.27616503834724426, 0.9514784812927246, 0.6568767428398132,
        0.9547983407974243
    ]], [[0.5047007203102112, 0.8876819610595703]],
    [[0.9892204403877258, 0.8454614877700806]]
]
# AssembleNet++ structure for 50 layer models, found with the Charades dataset.
# This is the model used in the experiments in the AssembleNet++ paper.
# Note that, in order to build AssembleNet++ with this structure, you also need
# to feed 'object segmentation input' to the network indicated as [-3, 4]. It's
# the 5th block in the architecture.
# If you don't plan to use the object input but still want to benefit from
# peer-attention in AssembleNet++ (with RGB and OF), please use the above
# AssembleNet-50 model instead with the assemblenet_plus.py code.
full_asnp50_structure = [[-1, 2], [-1, 4], [-2, 2], [-2, 1], [-3, 4],
                         [0, [0, 1, 2, 3, 4], 32, 1, 1, 0],
                         [0, [0, 1, 4], 32, 4, 1, 0],
                         [0, [2, 3, 4], 32, 8, 1, 0],
                         [0, [2, 3, 4], 32, 1, 1, 0],
                         [1, [0, 1, 2, 4, 5, 6, 7, 8], 64, 4, 2, 0],
                         [1, [2, 3, 4, 7, 8], 64, 1, 2, 0],
                         [1, [0, 4, 5, 6, 7], 128, 8, 2, 0],
                         [2, [4, 11], 256, 8, 2, 0],
                         [2, [2, 3, 4, 5, 6, 7, 8, 10, 11], 256, 4, 2, 0],
                         [3, [12, 13], 512, 2, 2, 0]]
full_asnp_structure_weights = [
    [], [], [], [], [],
    [[
        0.6143830418586731, 0.7111759185791016, 0.19351491332054138,
        0.1701001077890396, 0.7178536653518677
    ]], [[0.5755624771118164, 0.5644599795341492, 0.7128658294677734]],
    [[0.26563042402267456, 0.3033692538738251, 0.8244096636772156]],
    [[0.07013848423957825, 0.07905343919992447, 0.8767927885055542]],
    [[
        0.5008697509765625, 0.5020178556442261, 0.49819135665893555,
        0.5015180706977844, 0.4987695813179016, 0.4990265369415283,
        0.499239057302475, 0.4974501430988312
    ]],
    [[
        0.47034338116645813, 0.4694305658340454, 0.767791748046875,
        0.5539310574531555, 0.4520096182823181
    ]],
    [[
        0.2769702076911926, 0.8116549253463745, 0.597356915473938,
        0.6585626602172852, 0.5915306210517883
    ]], [[0.501274824142456, 0.5016682147979736]],
    [[
        0.0866393893957138, 0.08469288796186447, 0.9739039540290833,
        0.058271341025829315, 0.08397126197814941, 0.10285478830337524,
        0.18506969511508942, 0.23874442279338837, 0.9188644886016846
    ]], [[0.4174623489379883, 0.5844835638999939]]
]
# pylint: disable=line-too-long
# AssembleNet++lite structure using inverted bottleneck blocks. By specifying
# the connection weights as [], the model can also automatically learn the
# connection weights during its training.
asnp_lite_structure = [[-1, 1], [-2, 1],
                       [0, [0, 1], [27, 27, 12], 1, 2, 0],
                       [0, [0, 1], [27, 27, 12], 4, 2, 0],
                       [1, [0, 1, 2, 3], [54, 54, 24], 2, 2, 0],
                       [1, [0, 1, 2, 3], [54, 54, 24], 1, 2, 0],
                       [1, [0, 1, 2, 3], [54, 54, 24], 4, 2, 0],
                       [1, [0, 1, 2, 3], [54, 54, 24], 1, 2, 0],
                       [2, [0, 1, 2, 3, 4, 5, 6, 7], [152, 152, 68], 1, 2, 0],
                       [2, [0, 1, 2, 3, 4, 5, 6, 7], [152, 152, 68], 4, 2, 0],
                       [3, [2, 3, 4, 5, 6, 7, 8, 9], [432, 432, 192], 2, 2, 0]]
asnp_lite_structure_weights = [
    [], [], [[0.19914183020591736, 0.9278576374053955]],
    [[0.010816320776939392, 0.888792097568512]],
    [[
        0.9473835825920105, 0.6303419470787048, 0.1704932451248169,
        0.05950307101011276
    ]],
    [[
        0.9560931324958801, 0.7898273468017578, 0.36138781905174255,
        0.07344610244035721
    ]],
    [[
        0.9213919043540955, 0.13418640196323395, 0.8371981978416443,
        0.07936054468154907
    ]],
    [[
        0.9441559910774231, 0.9435100555419922, 0.7253988981246948,
        0.13498817384243011
    ]],
    [[
        0.9964852333068848, 0.8427878618240356, 0.8895476460456848,
        0.11014710366725922, 0.6270533204078674, 0.44782018661499023,
        0.61344975233078, 0.44898226857185364
    ]],
    [[
        0.9970942735671997, 0.7105681896209717, 0.5078442096710205,
        0.0951600968837738, 0.624282717704773, 0.8527252674102783,
        0.8105692863464355, 0.7857823967933655
    ]],
    [[
        0.6180334091186523, 0.11882413923740387, 0.06102970987558365,
        0.04484326392412186, 0.05602221190929413, 0.052324872463941574,
        0.9969874024391174, 0.9987731575965881
    ]]
]
# pylint: disable=line-too-long


@dataclasses.dataclass
class AssembleNet(hyperparams.Config):
  model_id: str = '50'
  num_frames: int = 0
  combine_method: str = 'sigmoid'
  blocks: Tuple[BlockSpec, ...] = tuple()


@dataclasses.dataclass
class Backbone3D(backbones_3d.Backbone3D):
  """Configuration for backbones.

  Attributes:
    type: 'str', type of backbone to be used, one of the fields below.
    resnet: resnet3d backbone config.
  """
  type: str = 'assemblenet'
  assemblenet: AssembleNet = AssembleNet()
@dataclasses.dataclass
class AssembleNetModel(video_classification.VideoClassificationModel):
  """The AssembleNet model config."""
  model_type: str = 'assemblenet'
  backbone: Backbone3D = Backbone3D()
  norm_activation: common.NormActivation = common.NormActivation(
      norm_momentum=0.99, norm_epsilon=1e-5, use_sync_bn=True)
  max_pool_preditions: bool = False


@exp_factory.register_config_factory('assemblenet50_kinetics600')
def assemblenet_kinetics600() -> cfg.ExperimentConfig:
  """Video classification on Videonet with assemblenet."""
  exp = video_classification.video_classification_kinetics600()
  feature_shape = (32, 224, 224, 3)

  exp.task.train_data.global_batch_size = 1024
  exp.task.validation_data.global_batch_size = 32
  exp.task.train_data.feature_shape = feature_shape
  exp.task.validation_data.feature_shape = (120, 224, 224, 3)
  exp.task.train_data.dtype = 'bfloat16'
  exp.task.validation_data.dtype = 'bfloat16'

  model = AssembleNetModel()
  model.backbone.assemblenet.model_id = '50'
  model.backbone.assemblenet.blocks = flat_lists_to_blocks(
      asn50_structure, asn_structure_weights)
  model.backbone.assemblenet.num_frames = feature_shape[0]
  exp.task.model = model

  assert exp.task.model.backbone.assemblenet.num_frames > 0, (
      f'backbone num_frames '
      f'{exp.task.model.backbone.assemblenet}')
  return exp
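
# Usage sketch (editor's illustration, not part of the original file): once this
# module is imported, the experiment registered above can be retrieved by name,
# for example
#
#   exp_config = exp_factory.get_exp_config('assemblenet50_kinetics600')
#
# assuming `exp_factory.get_exp_config` resolves registered factories as in
# other Model Garden configs.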
official/vision/beta/projects/assemblenet/modeling/assemblenet.py (new file, mode 100644)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Contains definitions for the AssembleNet [1] models.
Requires the AssembleNet architecture to be specified in
FLAGS.model_structure (and optionally FLAGS.model_edge_weights).
This structure is a list corresponding to a graph representation of the
network, where a node is a convolutional block and an edge specifies a
connection from one block to another as described in [1].
Each node itself (in the structure list) is a list with the following format:
[block_level, [list_of_input_blocks], number_filter, temporal_dilation,
spatial_stride]. [list_of_input_blocks] should be the list of node indexes whose
values are less than the index of the node itself. The 'stems' of the network
directly taking raw inputs follow a different node format:
[stem_type, temporal_dilation]. The stem_type is -1 for RGB stem and is -2 for
optical flow stem.
Also note that the codes in this file could be used for one-shot differentiable
connection search by (1) giving an overly connected structure as
FLAGS.model_structure and by (2) setting FLAGS.model_edge_weights to be '[]'.
The 'agg_weights' variables will specify which connections are needed and which
are not, once trained.
[1] Michael S. Ryoo, AJ Piergiovanni, Mingxing Tan, Anelia Angelova,
AssembleNet: Searching for Multi-Stream Neural Connectivity in Video
Architectures. ICLR 2020
https://arxiv.org/abs/1905.13209
It uses (2+1)D convolutions for video representations. The main AssembleNet
takes a 4-D (N*T)HWC tensor as an input (i.e., the batch dim and time dim are
mixed), and it reshapes a tensor to NT(H*W)C whenever a 1-D temporal conv. is
necessary. This is to run this on TPU efficiently.
"""
import functools
import math
from typing import Any, Mapping, List, Callable, Optional

from absl import logging
import numpy as np
import tensorflow as tf

from official.vision.beta.modeling import factory_3d as model_factory
from official.vision.beta.modeling.backbones import factory as backbone_factory
from official.vision.beta.projects.assemblenet.configs import assemblenet as cfg
from official.vision.beta.projects.assemblenet.modeling import rep_flow_2d_layer as rf

layers = tf.keras.layers

intermediate_channel_size = [64, 128, 256, 512]
def fixed_padding(inputs, kernel_size):
  """Pads the input along the spatial dimensions independently of input size.

  Args:
    inputs: `Tensor` of size `[batch, channels, height, width]` or `[batch,
      height, width, channels]` depending on `data_format`.
    kernel_size: `int` kernel size to be used for `conv2d` or `max_pool2d`
      operations. Should be a positive integer.

  Returns:
    A padded `Tensor` of the same `data_format` with size either intact
    (if `kernel_size == 1`) or padded (if `kernel_size > 1`).
  """
  data_format = tf.keras.backend.image_data_format()
  pad_total = kernel_size - 1
  pad_beg = pad_total // 2
  pad_end = pad_total - pad_beg
  if data_format == 'channels_first':
    padded_inputs = tf.pad(
        inputs, [[0, 0], [0, 0], [pad_beg, pad_end], [pad_beg, pad_end]])
  else:
    padded_inputs = tf.pad(
        inputs, [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]])

  return padded_inputs


def reshape_temporal_conv1d_bn(inputs: tf.Tensor,
                               filters: int,
                               kernel_size: int,
                               num_frames: int = 32,
                               temporal_dilation: int = 1,
                               bn_decay: float = rf.BATCH_NORM_DECAY,
                               bn_epsilon: float = rf.BATCH_NORM_EPSILON,
                               use_sync_bn: bool = False):
  """Performs 1D temporal conv. followed by batch normalization with reshaping.

  Args:
    inputs: `Tensor` of size `[batch*time, height, width, channels]`. Only
      supports 'channels_last' as the data format.
    filters: `int` number of filters in the convolution.
    kernel_size: `int` kernel size to be used for the 1D temporal conv.
      Should be a positive integer.
    num_frames: `int` number of frames in the input tensor.
    temporal_dilation: `int` temporal dilation size for the 1D conv.
    bn_decay: `float` batch norm decay parameter to use.
    bn_epsilon: `float` batch norm epsilon parameter to use.
    use_sync_bn: use synchronized batch norm for TPU.

  Returns:
    The output `Tensor` after the 1D temporal convolution, batch normalization,
    and ReLU, reshaped back to `[batch*time, height, width, filters]`.
  """
  data_format = tf.keras.backend.image_data_format()
  assert data_format == 'channels_last'

  feature_shape = inputs.shape

  inputs = tf.reshape(
      inputs,
      [-1, num_frames, feature_shape[1] * feature_shape[2], feature_shape[3]])

  if temporal_dilation == 1:
    inputs = tf.keras.layers.Conv2D(
        filters=filters,
        kernel_size=(kernel_size, 1),
        strides=1,
        padding='SAME',
        use_bias=False,
        kernel_initializer=tf.keras.initializers.VarianceScaling())(
            inputs=inputs)
  else:
    inputs = tf.keras.layers.Conv2D(
        filters=filters,
        kernel_size=(kernel_size, 1),
        strides=1,
        padding='SAME',
        dilation_rate=(temporal_dilation, 1),
        use_bias=False,
        kernel_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=math.sqrt(2.0 / (kernel_size * feature_shape[3]))))(
                inputs=inputs)

  num_channel = inputs.shape[3]
  inputs = tf.reshape(inputs,
                      [-1, feature_shape[1], feature_shape[2], num_channel])
  inputs = rf.build_batch_norm(
      bn_decay=bn_decay, bn_epsilon=bn_epsilon, use_sync_bn=use_sync_bn)(
          inputs)
  inputs = tf.nn.relu(inputs)

  return inputs
def conv2d_fixed_padding(inputs: tf.Tensor, filters: int, kernel_size: int,
                         strides: int):
  """Strided 2-D convolution with explicit padding.

  The padding is consistent and is based only on `kernel_size`, not on the
  dimensions of `inputs` (as opposed to using `tf.keras.layers.Conv2D` alone).

  Args:
    inputs: `Tensor` of size `[batch, channels, height_in, width_in]`.
    filters: `int` number of filters in the convolution.
    kernel_size: `int` size of the kernel to be used in the convolution.
    strides: `int` strides of the convolution.

  Returns:
    A `Tensor` of shape `[batch, filters, height_out, width_out]`.
  """
  if strides > 1:
    inputs = fixed_padding(inputs, kernel_size)

  return tf.keras.layers.Conv2D(
      filters=filters,
      kernel_size=kernel_size,
      strides=strides,
      padding=('SAME' if strides == 1 else 'VALID'),
      use_bias=False,
      kernel_initializer=tf.keras.initializers.VarianceScaling())(
          inputs=inputs)


def conv3d_same_padding(inputs: tf.Tensor,
                        filters: int,
                        kernel_size: int,
                        strides: int,
                        temporal_dilation: int = 1,
                        do_2d_conv: bool = False):
  """3D convolution layer wrapper.

  Uses conv3d function.

  Args:
    inputs: 5D `Tensor` following the data_format.
    filters: `int` number of filters in the convolution.
    kernel_size: `int` size of the kernel to be used in the convolution.
    strides: `int` strides of the convolution.
    temporal_dilation: `int` temporal dilation size for the 1D conv.
    do_2d_conv: `bool` indicating whether to do 2d conv. If false, do 3D conv.

  Returns:
    A `Tensor` of shape `[batch, time_in, height_in, width_in, channels]`.
  """
  if isinstance(kernel_size, int):
    if do_2d_conv:
      kernel_size = [1, kernel_size, kernel_size]
    else:
      kernel_size = [kernel_size, kernel_size, kernel_size]

  return tf.keras.layers.Conv3D(
      filters=filters,
      kernel_size=kernel_size,
      strides=[1, strides, strides],
      padding='SAME',
      dilation_rate=[temporal_dilation, 1, 1],
      use_bias=False,
      kernel_initializer=tf.keras.initializers.VarianceScaling())(
          inputs=inputs)
def bottleneck_block_interleave(inputs: tf.Tensor,
                                filters: int,
                                inter_filters: int,
                                strides: int,
                                use_projection: bool = False,
                                num_frames: int = 32,
                                temporal_dilation: int = 1,
                                bn_decay: float = rf.BATCH_NORM_DECAY,
                                bn_epsilon: float = rf.BATCH_NORM_EPSILON,
                                use_sync_bn: bool = False,
                                step=1):
  """Interleaves a standard 2D residual module and (2+1)D residual module.

  Bottleneck block variant for residual networks with BN after convolutions.

  Args:
    inputs: `Tensor` of size `[batch*time, channels, height, width]`.
    filters: `int` number of filters for the first conv. layer. The last conv.
      layer will use 4 times as many filters.
    inter_filters: `int` number of filters for the second conv. layer.
    strides: `int` block stride. If greater than 1, this block will ultimately
      downsample the input spatially.
    use_projection: `bool` for whether this block should use a projection
      shortcut (versus the default identity shortcut). This is usually `True`
      for the first block of a block group, which may change the number of
      filters and the resolution.
    num_frames: `int` number of frames in the input tensor.
    temporal_dilation: `int` temporal dilation size for the 1D conv.
    bn_decay: `float` batch norm decay parameter to use.
    bn_epsilon: `float` batch norm epsilon parameter to use.
    use_sync_bn: use synchronized batch norm for TPU.
    step: `int` to decide whether to put 2D module or (2+1)D module.

  Returns:
    The output `Tensor` of the block.
  """
  if strides > 1 and not use_projection:
    raise ValueError('strides > 1 requires use_projections=True, otherwise the '
                     'inputs and shortcut will have shape mismatch')
  shortcut = inputs
  if use_projection:
    # Projection shortcut only in first block within a group. Bottleneck blocks
    # end with 4 times the number of filters.
    filters_out = 4 * filters
    shortcut = conv2d_fixed_padding(
        inputs=inputs, filters=filters_out, kernel_size=1, strides=strides)
    shortcut = rf.build_batch_norm(
        bn_decay=bn_decay, bn_epsilon=bn_epsilon, use_sync_bn=use_sync_bn)(
            shortcut)

  if step % 2 == 1:
    k = 3

    inputs = reshape_temporal_conv1d_bn(
        inputs=inputs,
        filters=filters,
        kernel_size=k,
        num_frames=num_frames,
        temporal_dilation=temporal_dilation,
        bn_decay=bn_decay,
        bn_epsilon=bn_epsilon,
        use_sync_bn=use_sync_bn)
  else:
    inputs = conv2d_fixed_padding(
        inputs=inputs, filters=filters, kernel_size=1, strides=1)
    inputs = rf.build_batch_norm(
        bn_decay=bn_decay, bn_epsilon=bn_epsilon, use_sync_bn=use_sync_bn)(
            inputs)
    inputs = tf.nn.relu(inputs)

  inputs = conv2d_fixed_padding(
      inputs=inputs, filters=inter_filters, kernel_size=3, strides=strides)
  inputs = rf.build_batch_norm(
      bn_decay=bn_decay, bn_epsilon=bn_epsilon, use_sync_bn=use_sync_bn)(
          inputs)
  inputs = tf.nn.relu(inputs)

  inputs = conv2d_fixed_padding(
      inputs=inputs, filters=4 * filters, kernel_size=1, strides=1)
  inputs = rf.build_batch_norm(
      init_zero=True,
      bn_decay=bn_decay,
      bn_epsilon=bn_epsilon,
      use_sync_bn=use_sync_bn)(
          inputs)

  return tf.nn.relu(inputs + shortcut)
def block_group(inputs: tf.Tensor,
                filters: int,
                block_fn: Callable[..., tf.Tensor],
                blocks: int,
                strides: int,
                name,
                block_level,
                num_frames=32,
                temporal_dilation=1):
  """Creates one group of blocks for the AssembleNet model.

  Args:
    inputs: `Tensor` of size `[batch*time, channels, height, width]`.
    filters: `int` number of filters for the first convolution of the layer.
    block_fn: `function` for the block to use within the model
    blocks: `int` number of blocks contained in the layer.
    strides: `int` stride to use for the first convolution of the layer. If
      greater than 1, this layer will downsample the input.
    name: `str` name for the Tensor output of the block layer.
    block_level: `int` block level in AssembleNet.
    num_frames: `int` number of frames in the input tensor.
    temporal_dilation: `int` temporal dilation size for the 1D conv.

  Returns:
    The output `Tensor` of the block layer.
  """
  # Only the first block per block_group uses projection shortcut and strides.
  inputs = block_fn(
      inputs,
      filters,
      intermediate_channel_size[block_level],
      strides,
      use_projection=True,
      num_frames=num_frames,
      temporal_dilation=temporal_dilation,
      step=0)

  for i in range(1, blocks):
    inputs = block_fn(
        inputs,
        filters,
        intermediate_channel_size[block_level],
        1,
        num_frames=num_frames,
        temporal_dilation=temporal_dilation,
        step=i)

  return tf.identity(inputs, name)
def spatial_resize_and_concat(inputs):
  """Concatenates multiple different sized tensors channel-wise.

  Args:
    inputs: A list of `Tensors` of size `[batch*time, channels, height, width]`.

  Returns:
    The output `Tensor` after concatenation.
  """
  data_format = tf.keras.backend.image_data_format()
  assert data_format == 'channels_last'

  # Do nothing if only 1 input
  if len(inputs) == 1:
    return inputs[0]
  if data_format != 'channels_last':
    return inputs

  # get smallest spatial size and largest channels
  sm_size = [1000, 1000]
  for inp in inputs:
    # assume batch X height x width x channels
    sm_size[0] = min(sm_size[0], inp.shape[1])
    sm_size[1] = min(sm_size[1], inp.shape[2])

  for i in range(len(inputs)):
    if inputs[i].shape[1] != sm_size[0] or inputs[i].shape[2] != sm_size[1]:
      ratio = (inputs[i].shape[1] + 1) // sm_size[0]
      inputs[i] = tf.keras.layers.MaxPool2D([ratio, ratio],
                                            ratio,
                                            padding='same')(
                                                inputs[i])

  return tf.concat(inputs, 3)
class _ApplyEdgeWeight(layers.Layer):
  """Multiply weight on each input tensor.

  A weight is assigned for each connection (i.e., each input tensor). This
  layer is used by the multi_connection_fusion to compute the weighted inputs.
  """

  def __init__(self,
               weights_shape,
               index: int = None,
               use_5d_mode: bool = False,
               model_edge_weights: Optional[List[Any]] = None,
               **kwargs):
    """Constructor.

    Args:
      weights_shape: shape of the weights. Should equal [len(inputs)].
      index: `int` index of the block within the AssembleNet architecture. Used
        for summation weight initial loading.
      use_5d_mode: `bool` indicating whether the inputs are in 5D tensor or 4D.
      model_edge_weights: AssembleNet model structure connection weights in the
        string format.
      **kwargs: pass through arguments.
    """
    super(_ApplyEdgeWeight, self).__init__(**kwargs)

    self._weights_shape = weights_shape
    self._index = index
    self._use_5d_mode = use_5d_mode
    self._model_edge_weights = model_edge_weights
    data_format = tf.keras.backend.image_data_format()
    assert data_format == 'channels_last'

  def get_config(self):
    config = {
        'weights_shape': self._weights_shape,
        'index': self._index,
        'use_5d_mode': self._use_5d_mode,
        'model_edge_weights': self._model_edge_weights,
    }
    base_config = super(_ApplyEdgeWeight, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def build(self, input_shape: tf.TensorShape):
    if self._weights_shape[0] == 1:
      self._edge_weights = 1.0
      return

    if self._index is None or not self._model_edge_weights:
      self._edge_weights = self.add_weight(
          shape=self._weights_shape,
          initializer=tf.keras.initializers.TruncatedNormal(
              mean=0.0, stddev=0.01),
          trainable=True,
          name='agg_weights')
    else:
      initial_weights_after_sigmoid = np.asarray(
          self._model_edge_weights[self._index][0]).astype('float32')
      # Initial_weights_after_sigmoid is never 0, as the initial weights are
      # based on the results of a successful connectivity search.
      initial_weights = -np.log(1. / initial_weights_after_sigmoid - 1.)
      self._edge_weights = self.add_weight(
          shape=self._weights_shape,
          initializer=tf.constant_initializer(initial_weights),
          trainable=False,
          name='agg_weights')

  def call(self,
           inputs: List[tf.Tensor],
           training: bool = None) -> Mapping[Any, List[tf.Tensor]]:
    use_5d_mode = self._use_5d_mode
    dtype = inputs[0].dtype
    assert len(inputs) > 1

    if use_5d_mode:
      h_channel_loc = 2
    else:
      h_channel_loc = 1

    # get smallest spatial size and largest channels
    sm_size = [10000, 10000]
    lg_channel = 0
    for inp in inputs:
      # assume batch X height x width x channels
      sm_size[0] = min(sm_size[0], inp.shape[h_channel_loc])
      sm_size[1] = min(sm_size[1], inp.shape[h_channel_loc + 1])
      lg_channel = max(lg_channel, inp.shape[-1])

    # loads or creates weight variables to fuse multiple inputs
    weights = tf.math.sigmoid(tf.cast(self._edge_weights, dtype))

    # Compute weighted inputs. We group inputs with the same channels.
    per_channel_inps = dict({0: []})
    for i, inp in enumerate(inputs):
      if inp.shape[h_channel_loc] != sm_size[0] or inp.shape[h_channel_loc + 1] != sm_size[1]:  # pylint: disable=line-too-long
        assert sm_size[0] != 0
        ratio = (inp.shape[h_channel_loc] + 1) // sm_size[0]
        if use_5d_mode:
          inp = tf.keras.layers.MaxPool3D([1, ratio, ratio], [1, ratio, ratio],
                                          padding='same')(
                                              inp)
        else:
          inp = tf.keras.layers.MaxPool2D([ratio, ratio], ratio,
                                          padding='same')(
                                              inp)

      weights = tf.cast(weights, inp.dtype)
      if inp.shape[-1] in per_channel_inps:
        per_channel_inps[inp.shape[-1]].append(weights[i] * inp)
      else:
        per_channel_inps.update({inp.shape[-1]: [weights[i] * inp]})
    return per_channel_inps
def multi_connection_fusion(inputs: List[tf.Tensor],
                            index: int = None,
                            use_5d_mode: bool = False,
                            model_edge_weights: Optional[List[Any]] = None):
  """Do weighted summation of multiple different sized tensors.

  A weight is assigned for each connection (i.e., each input tensor), and their
  summation weights are learned. Uses spatial max pooling and 1x1 conv.
  to match their sizes.

  Args:
    inputs: A `Tensor`. Either 4D or 5D, depending of use_5d_mode.
    index: `int` index of the block within the AssembleNet architecture. Used
      for summation weight initial loading.
    use_5d_mode: `bool` indicating whether the inputs are in 5D tensor or 4D.
    model_edge_weights: AssembleNet model structure connection weights in the
      string format.

  Returns:
    The output `Tensor` after concatenation.
  """

  if use_5d_mode:
    h_channel_loc = 2
    conv_function = conv3d_same_padding
  else:
    h_channel_loc = 1
    conv_function = conv2d_fixed_padding

  # If only 1 input.
  if len(inputs) == 1:
    return inputs[0]

  # get smallest spatial size and largest channels
  sm_size = [10000, 10000]
  lg_channel = 0
  for inp in inputs:
    # assume batch X height x width x channels
    sm_size[0] = min(sm_size[0], inp.shape[h_channel_loc])
    sm_size[1] = min(sm_size[1], inp.shape[h_channel_loc + 1])
    lg_channel = max(lg_channel, inp.shape[-1])

  per_channel_inps = _ApplyEdgeWeight(
      weights_shape=[len(inputs)],
      index=index,
      use_5d_mode=use_5d_mode,
      model_edge_weights=model_edge_weights)(
          inputs)

  # Adding 1x1 conv layers (to match channel size) and fusing all inputs.
  # We add inputs with the same channels first before applying 1x1 conv to save
  # memory.
  inps = []
  for key, channel_inps in per_channel_inps.items():
    if len(channel_inps) < 1:
      continue
    if len(channel_inps) == 1:
      if key == lg_channel:
        inp = channel_inps[0]
      else:
        inp = conv_function(
            channel_inps[0], lg_channel, kernel_size=1, strides=1)
      inps.append(inp)
    else:
      if key == lg_channel:
        inp = tf.add_n(channel_inps)
      else:
        inp = conv_function(
            tf.add_n(channel_inps), lg_channel, kernel_size=1, strides=1)
      inps.append(inp)

  return tf.add_n(inps)
def rgb_conv_stem(inputs,
                  num_frames,
                  filters,
                  temporal_dilation,
                  bn_decay: float = rf.BATCH_NORM_DECAY,
                  bn_epsilon: float = rf.BATCH_NORM_EPSILON,
                  use_sync_bn: bool = False):
  """Layers for a RGB stem.

  Args:
    inputs: A `Tensor` of size `[batch*time, height, width, channels]`.
    num_frames: `int` number of frames in the input tensor.
    filters: `int` number of filters in the convolution.
    temporal_dilation: `int` temporal dilation size for the 1D conv.
    bn_decay: `float` batch norm decay parameter to use.
    bn_epsilon: `float` batch norm epsilon parameter to use.
    use_sync_bn: use synchronized batch norm for TPU.

  Returns:
    The output `Tensor`.
  """
  data_format = tf.keras.backend.image_data_format()
  assert data_format == 'channels_last'

  if temporal_dilation < 1:
    temporal_dilation = 1

  inputs = conv2d_fixed_padding(
      inputs=inputs, filters=filters, kernel_size=7, strides=2)
  inputs = tf.identity(inputs, 'initial_conv')
  inputs = rf.build_batch_norm(
      bn_decay=bn_decay, bn_epsilon=bn_epsilon, use_sync_bn=use_sync_bn)(
          inputs)
  inputs = tf.nn.relu(inputs)

  inputs = reshape_temporal_conv1d_bn(
      inputs=inputs,
      filters=filters,
      kernel_size=5,
      num_frames=num_frames,
      temporal_dilation=temporal_dilation,
      bn_decay=bn_decay,
      bn_epsilon=bn_epsilon,
      use_sync_bn=use_sync_bn)

  inputs = tf.keras.layers.MaxPool2D(
      pool_size=3, strides=2, padding='SAME')(
          inputs=inputs)
  inputs = tf.identity(inputs, 'initial_max_pool')

  return inputs


def flow_conv_stem(inputs,
                   filters,
                   temporal_dilation,
                   bn_decay: float = rf.BATCH_NORM_DECAY,
                   bn_epsilon: float = rf.BATCH_NORM_EPSILON,
                   use_sync_bn: bool = False):
  """Layers for an optical flow stem.

  Args:
    inputs: A `Tensor` of size `[batch*time, height, width, channels]`.
    filters: `int` number of filters in the convolution.
    temporal_dilation: `int` temporal dilation size for the 1D conv.
    bn_decay: `float` batch norm decay parameter to use.
    bn_epsilon: `float` batch norm epsilon parameter to use.
    use_sync_bn: use synchronized batch norm for TPU.

  Returns:
    The output `Tensor`.
  """
  if temporal_dilation < 1:
    temporal_dilation = 1

  inputs = conv2d_fixed_padding(
      inputs=inputs, filters=filters, kernel_size=7, strides=2)
  inputs = tf.identity(inputs, 'initial_conv')
  inputs = rf.build_batch_norm(
      bn_decay=bn_decay, bn_epsilon=bn_epsilon, use_sync_bn=use_sync_bn)(
          inputs)
  inputs = tf.nn.relu(inputs)

  inputs = tf.keras.layers.MaxPool2D(
      pool_size=2, strides=2, padding='SAME')(
          inputs=inputs)
  inputs = tf.identity(inputs, 'initial_max_pool')

  return inputs
def multi_stream_heads(streams,
                       final_nodes,
                       num_frames,
                       num_classes,
                       max_pool_preditions: bool = False):
  """Layers for the classification heads.

  Args:
    streams: A list of 4D `Tensors` following the data_format.
    final_nodes: A list of `int` where classification heads will be added.
    num_frames: `int` number of frames in the input tensor.
    num_classes: `int` number of possible classes for video classification.
    max_pool_preditions: Use max-pooling on predictions instead of mean
      pooling on features. It helps if you have more than 32 frames.

  Returns:
    The output `Tensor`.
  """
  inputs = streams[final_nodes[0]]
  num_channels = inputs.shape[-1]

  def _pool_and_reshape(net):
    # The activation is 7x7 so this is a global average pool.
    net = tf.keras.layers.GlobalAveragePooling2D()(inputs=net)
    net = tf.identity(net, 'final_avg_pool0')

    net = tf.reshape(net, [-1, num_frames, num_channels])
    if not max_pool_preditions:
      net = tf.reduce_mean(net, 1)
    return net

  outputs = _pool_and_reshape(inputs)

  for i in range(1, len(final_nodes)):
    inputs = streams[final_nodes[i]]

    inputs = _pool_and_reshape(inputs)

    outputs = outputs + inputs

  if len(final_nodes) > 1:
    outputs = outputs / len(final_nodes)

  outputs = tf.keras.layers.Dense(
      units=num_classes,
      kernel_initializer=tf.random_normal_initializer(stddev=.01))(
          inputs=outputs)
  outputs = tf.identity(outputs, 'final_dense0')
  if max_pool_preditions:
    pre_logits = outputs / np.sqrt(num_frames)
    acts = tf.nn.softmax(pre_logits, axis=1)
    outputs = tf.math.multiply(outputs, acts)

    outputs = tf.reduce_sum(outputs, 1)

  return outputs
class AssembleNet(tf.keras.Model):
  """AssembleNet backbone."""

  def __init__(self,
               block_fn,
               num_blocks: List[int],
               num_frames: int,
               model_structure: List[Any],
               input_specs: layers.InputSpec = layers.InputSpec(
                   shape=[None, None, None, None, 3]),
               model_edge_weights: Optional[List[Any]] = None,
               bn_decay: float = rf.BATCH_NORM_DECAY,
               bn_epsilon: float = rf.BATCH_NORM_EPSILON,
               use_sync_bn: bool = False,
               combine_method: str = 'sigmoid',
               **kwargs):
    """Generator for AssembleNet v1 models.

    Args:
      block_fn: `function` for the block to use within the model. Currently
        only has `bottleneck_block_interleave` as its option.
      num_blocks: list of 4 `int`s denoting the number of blocks to include in
        each of the 4 block groups. Each group consists of blocks that take
        inputs of the same resolution.
      num_frames: the number of frames in the input tensor.
      model_structure: AssembleNet model structure in the string format.
      input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
        Dimension should be `[batch*time, height, width, channels]`.
      model_edge_weights: AssembleNet model structure connection weights in the
        string format.
      bn_decay: `float` batch norm decay parameter to use.
      bn_epsilon: `float` batch norm epsilon parameter to use.
      use_sync_bn: use synchronized batch norm for TPU.
      combine_method: 'str' for the weighted summation to fuse different
        blocks.
      **kwargs: pass through arguments.
    """
    inputs = tf.keras.Input(shape=input_specs.shape[1:])
    data_format = tf.keras.backend.image_data_format()

    # Creation of the model graph.
    logging.info('model_structure=%r', model_structure)
    logging.info('model_edge_weights=%r', model_edge_weights)
    structure = model_structure

    original_num_frames = num_frames
    assert num_frames > 0, f'Invalid num_frames {num_frames}'

    grouping = {-3: [], -2: [], -1: [], 0: [], 1: [], 2: [], 3: []}
    for i in range(len(structure)):
      grouping[structure[i][0]].append(i)

    stem_count = len(grouping[-3]) + len(grouping[-2]) + len(grouping[-1])

    assert stem_count != 0
    stem_filters = 128 // stem_count

    original_inputs = inputs
    if len(input_specs.shape) == 5:
      first_dim = (
          input_specs.shape[0] * input_specs.shape[1]
          if input_specs.shape[0] and input_specs.shape[1] else -1)
      reshape_inputs = tf.reshape(inputs, (first_dim,) + input_specs.shape[2:])
    elif len(input_specs.shape) == 4:
      reshape_inputs = original_inputs
    else:
      raise ValueError(
          f'Expect input spec to be 4 or 5 dimensions {input_specs.shape}')
    if grouping[-2]:
      # Instead of loading optical flows as inputs from data pipeline, we are
      # applying the "Representation Flow" to RGB frames so that we can compute
      # the flow within TPU/GPU on fly. It's essentially optical flow since we
      # do it with RGBs.
      axis = 3 if data_format == 'channels_last' else 1
      flow_inputs = rf.RepresentationFlow(
          original_num_frames,
          depth=reshape_inputs.shape.as_list()[axis],
          num_iter=40,
          bottleneck=1)(
              reshape_inputs)
    streams = []

    for i in range(len(structure)):
      with tf.name_scope('Node_' + str(i)):
        if structure[i][0] == -1:
          inputs = rgb_conv_stem(
              reshape_inputs,
              original_num_frames,
              stem_filters,
              temporal_dilation=structure[i][1],
              bn_decay=bn_decay,
              bn_epsilon=bn_epsilon,
              use_sync_bn=use_sync_bn)
          streams.append(inputs)
        elif structure[i][0] == -2:
          inputs = flow_conv_stem(
              flow_inputs,
              stem_filters,
              temporal_dilation=structure[i][1],
              bn_decay=bn_decay,
              bn_epsilon=bn_epsilon,
              use_sync_bn=use_sync_bn)
          streams.append(inputs)

        else:
          num_frames = original_num_frames
          block_number = structure[i][0]

          combined_inputs = []
          if combine_method == 'concat':
            combined_inputs = [
                streams[structure[i][1][j]]
                for j in range(0, len(structure[i][1]))
            ]

            combined_inputs = spatial_resize_and_concat(combined_inputs)

          else:
            combined_inputs = [
                streams[structure[i][1][j]]
                for j in range(0, len(structure[i][1]))
            ]

            combined_inputs = multi_connection_fusion(
                combined_inputs, index=i, model_edge_weights=model_edge_weights)

          graph = block_group(
              inputs=combined_inputs,
              filters=structure[i][2],
              block_fn=block_fn,
              blocks=num_blocks[block_number],
              strides=structure[i][4],
              name='block_group' + str(i),
              block_level=structure[i][0],
              num_frames=num_frames,
              temporal_dilation=structure[i][3])

          streams.append(graph)

    super(AssembleNet, self).__init__(
        inputs=original_inputs, outputs=streams, **kwargs)
@tf.keras.utils.register_keras_serializable(package='Vision')
class AssembleNetModel(tf.keras.Model):
  """An AssembleNet model builder."""

  def __init__(self,
               backbone,
               num_classes,
               num_frames: int,
               model_structure: List[Any],
               input_specs: Mapping[str, tf.keras.layers.InputSpec] = None,
               max_pool_preditions: bool = False,
               **kwargs):
    if not input_specs:
      input_specs = {
          'image': layers.InputSpec(shape=[None, None, None, None, 3])
      }
    self._self_setattr_tracking = False
    self._config_dict = {
        'backbone': backbone,
        'num_classes': num_classes,
        'num_frames': num_frames,
        'input_specs': input_specs,
        'model_structure': model_structure,
    }
    self._input_specs = input_specs
    self._backbone = backbone
    grouping = {-3: [], -2: [], -1: [], 0: [], 1: [], 2: [], 3: []}
    for i in range(len(model_structure)):
      grouping[model_structure[i][0]].append(i)

    inputs = {
        k: tf.keras.Input(shape=v.shape[1:]) for k, v in input_specs.items()
    }
    streams = self._backbone(inputs['image'])

    outputs = multi_stream_heads(
        streams,
        grouping[3],
        num_frames,
        num_classes,
        max_pool_preditions=max_pool_preditions)

    super(AssembleNetModel, self).__init__(
        inputs=inputs, outputs=outputs, **kwargs)

  @property
  def checkpoint_items(self):
    """Returns a dictionary of items to be additionally checkpointed."""
    return dict(backbone=self.backbone)

  @property
  def backbone(self):
    return self._backbone

  def get_config(self):
    return self._config_dict

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)
ASSEMBLENET_SPECS = {
    26: {
        'block': bottleneck_block_interleave,
        'num_blocks': [2, 2, 2, 2]
    },
    38: {
        'block': bottleneck_block_interleave,
        'num_blocks': [2, 4, 4, 2]
    },
    50: {
        'block': bottleneck_block_interleave,
        'num_blocks': [3, 4, 6, 3]
    },
    68: {
        'block': bottleneck_block_interleave,
        'num_blocks': [3, 4, 12, 3]
    },
    77: {
        'block': bottleneck_block_interleave,
        'num_blocks': [3, 4, 15, 3]
    },
    101: {
        'block': bottleneck_block_interleave,
        'num_blocks': [3, 4, 23, 3]
    },
}
def assemblenet_v1(assemblenet_depth: int,
                   num_classes: int,
                   num_frames: int,
                   model_structure: List[Any],
                   input_specs: layers.InputSpec = layers.InputSpec(
                       shape=[None, None, None, None, 3]),
                   model_edge_weights: Optional[List[Any]] = None,
                   max_pool_preditions: bool = False,
                   combine_method: str = 'sigmoid',
                   **kwargs):
  """Returns the AssembleNet model for a given size and number of output classes."""

  data_format = tf.keras.backend.image_data_format()
  assert data_format == 'channels_last'

  if assemblenet_depth not in ASSEMBLENET_SPECS:
    raise ValueError('Not a valid assemblenet_depth:', assemblenet_depth)

  input_specs_dict = {'image': input_specs}
  params = ASSEMBLENET_SPECS[assemblenet_depth]
  backbone = AssembleNet(
      block_fn=params['block'],
      num_blocks=params['num_blocks'],
      num_frames=num_frames,
      model_structure=model_structure,
      input_specs=input_specs,
      model_edge_weights=model_edge_weights,
      combine_method=combine_method,
      **kwargs)
  return AssembleNetModel(
      backbone,
      num_classes=num_classes,
      num_frames=num_frames,
      model_structure=model_structure,
      input_specs=input_specs_dict,
      max_pool_preditions=max_pool_preditions,
      **kwargs)
@backbone_factory.register_backbone_builder('assemblenet')
def build_assemblenet_v1(
    input_specs: tf.keras.layers.InputSpec,
    model_config: cfg.Backbone3D,
    l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
  """Builds assemblenet backbone."""
  del l2_regularizer

  backbone_type = model_config.backbone.type
  backbone_cfg = model_config.backbone.get()
  norm_activation_config = model_config.norm_activation
  assert backbone_type == 'assemblenet'

  assemblenet_depth = int(backbone_cfg.model_id)
  if assemblenet_depth not in ASSEMBLENET_SPECS:
    raise ValueError('Not a valid assemblenet_depth:', assemblenet_depth)
  model_structure, model_edge_weights = cfg.blocks_to_flat_lists(
      backbone_cfg.blocks)
  params = ASSEMBLENET_SPECS[assemblenet_depth]
  block_fn = functools.partial(
      params['block'],
      use_sync_bn=norm_activation_config.use_sync_bn,
      bn_decay=norm_activation_config.norm_momentum,
      bn_epsilon=norm_activation_config.norm_epsilon)
  backbone = AssembleNet(
      block_fn=block_fn,
      num_blocks=params['num_blocks'],
      num_frames=backbone_cfg.num_frames,
      model_structure=model_structure,
      input_specs=input_specs,
      model_edge_weights=model_edge_weights,
      combine_method=backbone_cfg.combine_method,
      use_sync_bn=norm_activation_config.use_sync_bn,
      bn_decay=norm_activation_config.norm_momentum,
      bn_epsilon=norm_activation_config.norm_epsilon)
  logging.info('Number of parameters in AssembleNet backbone: %f M.',
               backbone.count_params() / 10.**6)
  return backbone
@model_factory.register_model_builder('assemblenet')
def build_assemblenet_model(
    input_specs: tf.keras.layers.InputSpec,
    model_config: cfg.AssembleNetModel,
    num_classes: int,
    l2_regularizer: tf.keras.regularizers.Regularizer = None):
  """Builds assemblenet model."""
  input_specs_dict = {'image': input_specs}
  backbone = build_assemblenet_v1(input_specs, model_config, l2_regularizer)
  backbone_cfg = model_config.backbone.get()
  model_structure, _ = cfg.blocks_to_flat_lists(backbone_cfg.blocks)
  model = AssembleNetModel(
      backbone,
      num_classes=num_classes,
      num_frames=backbone_cfg.num_frames,
      model_structure=model_structure,
      input_specs=input_specs_dict,
      max_pool_preditions=model_config.max_pool_preditions)
  return model
official/vision/beta/projects/assemblenet/modeling/rep_flow_2d_layer.py (new file, mode 100644)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Contains definitions for 'Representation Flow' layer [1].
Representation flow layer is a generalization of optical flow extraction; the
layer could be inserted anywhere within a CNN to capture feature movements. This
is the version taking 4D tensor with the shape [batch*time, height, width,
channels], to make this run on TPU.
[1] AJ Piergiovanni and Michael S. Ryoo,
Representation Flow for Action Recognition. CVPR 2019.
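
Usage sketch (editor's illustration, not part of the original docstring): for a
[batch*time, height, width, channels] feature map with 32 frames per clip and
64 channels, the layer defined below could be applied as

  flow_layer = RepresentationFlow(time=32, depth=64, num_iter=20, bottleneck=32)
  flow_features = flow_layer(features)  # same shape as `features`

where `features` and the chosen hyperparameter values are made-up examples.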
"""
import numpy as np
import tensorflow as tf

layers = tf.keras.layers

BATCH_NORM_DECAY = 0.99
BATCH_NORM_EPSILON = 1e-5


def build_batch_norm(init_zero: bool = False,
                     bn_decay: float = BATCH_NORM_DECAY,
                     bn_epsilon: float = BATCH_NORM_EPSILON,
                     use_sync_bn: bool = False):
  """Performs a batch normalization followed by a ReLU.

  Args:
    init_zero: `bool` if True, initializes scale parameter of batch
      normalization with 0 instead of 1 (default).
    bn_decay: `float` batch norm decay parameter to use.
    bn_epsilon: `float` batch norm epsilon parameter to use.
    use_sync_bn: use synchronized batch norm for TPU.

  Returns:
    A normalized `Tensor` with the same `data_format`.
  """
  if init_zero:
    gamma_initializer = tf.zeros_initializer()
  else:
    gamma_initializer = tf.ones_initializer()

  data_format = tf.keras.backend.image_data_format()
  assert data_format == 'channels_last'

  if data_format == 'channels_first':
    axis = 1
  else:
    axis = -1

  if use_sync_bn:
    batch_norm = layers.experimental.SyncBatchNormalization(
        axis=axis,
        momentum=bn_decay,
        epsilon=bn_epsilon,
        gamma_initializer=gamma_initializer)
  else:
    batch_norm = layers.BatchNormalization(
        axis=axis,
        momentum=bn_decay,
        epsilon=bn_epsilon,
        fused=True,
        gamma_initializer=gamma_initializer)

  return batch_norm
def divergence(p1, p2, f_grad_x, f_grad_y, name):
  """Computes the divergence value used with TV-L1 optical flow algorithm.

  Args:
    p1: 'Tensor' input.
    p2: 'Tensor' input in the next frame.
    f_grad_x: 'Tensor' x gradient of F value used in TV-L1.
    f_grad_y: 'Tensor' y gradient of F value used in TV-L1.
    name: 'str' name for the variable scope.

  Returns:
    A `Tensor` with the same `data_format` and shape as input.
  """
  data_format = tf.keras.backend.image_data_format()
  df = 'NHWC' if data_format == 'channels_last' else 'NCHW'
  with tf.name_scope('divergence_' + name):
    if data_format == 'channels_last':
      p1 = tf.pad(p1[:, :, :-1, :], [[0, 0], [0, 0], [1, 0], [0, 0]])
      p2 = tf.pad(p2[:, :-1, :, :], [[0, 0], [1, 0], [0, 0], [0, 0]])
    else:
      p1 = tf.pad(p1[:, :, :, :-1], [[0, 0], [0, 0], [0, 0], [1, 0]])
      p2 = tf.pad(p2[:, :, :-1, :], [[0, 0], [0, 0], [1, 0], [0, 0]])

    grad_x = tf.nn.conv2d(p1, f_grad_x, [1, 1, 1, 1], 'SAME', data_format=df)
    grad_y = tf.nn.conv2d(p2, f_grad_y, [1, 1, 1, 1], 'SAME', data_format=df)
    return grad_x + grad_y
def forward_grad(x, f_grad_x, f_grad_y, name):
  data_format = tf.keras.backend.image_data_format()
  with tf.name_scope('forward_grad_' + name):
    df = 'NHWC' if data_format == 'channels_last' else 'NCHW'
    grad_x = tf.nn.conv2d(x, f_grad_x, [1, 1, 1, 1], 'SAME', data_format=df)
    grad_y = tf.nn.conv2d(x, f_grad_y, [1, 1, 1, 1], 'SAME', data_format=df)
    return grad_x, grad_y


def norm_img(x):
  mx = tf.reduce_max(x)
  mn = tf.reduce_min(x)
  if mx == mn:
    return x
  else:
    return 255 * (x - mn) / (mx - mn)
class
RepresentationFlow
(
layers
.
Layer
):
"""Computes the representation flow motivated by TV-L1 optical flow."""
def
__init__
(
self
,
time
:
int
,
depth
:
int
,
num_iter
:
int
=
20
,
bottleneck
:
int
=
32
,
train_feature_grad
:
bool
=
False
,
train_divergence
:
bool
=
False
,
train_flow_grad
:
bool
=
False
,
train_hyper
:
bool
=
False
,
**
kwargs
):
"""Constructor.
Args:
time: 'int' number of frames in the input tensor.
depth: channel depth of the input tensor.
num_iter: 'int' number of iterations to use for the flow computation.
bottleneck: 'int' number of filters to be used for the flow computation.
train_feature_grad: Train image grad params.
train_divergence: train divergence params
train_flow_grad: train flow grad params.
train_hyper: train rep flow hyperparams.
**kwargs: keyword arguments to be passed to the parent constructor.
Returns:
A `Tensor` with the same `data_format` and shape as input.
"""
super
(
RepresentationFlow
,
self
).
__init__
(
**
kwargs
)
self
.
_time
=
time
self
.
_depth
=
depth
self
.
_num_iter
=
num_iter
self
.
_bottleneck
=
bottleneck
self
.
_train_feature_grad
=
train_feature_grad
self
.
_train_divergence
=
train_divergence
self
.
_train_flow_grad
=
train_flow_grad
self
.
_train_hyper
=
train_hyper
def
get_config
(
self
):
config
=
{
'time'
:
self
.
_time
,
'num_iter'
:
self
.
_num_iter
,
'bottleneck'
:
self
.
_bottleneck
,
'train_feature_grad'
:
self
.
_train_feature_grad
,
'train_divergence'
:
self
.
_train_divergence
,
'train_flow_grad'
:
self
.
_train_flow_grad
,
'train_hyper'
:
self
.
_train_hyper
,
}
base_config
=
super
(
RepresentationFlow
,
self
).
get_config
()
return
dict
(
list
(
base_config
.
items
())
+
list
(
config
.
items
()))
def
build
(
self
,
input_shape
:
tf
.
TensorShape
):
img_grad
=
np
.
array
([
-
0.5
,
0
,
0.5
],
dtype
=
'float32'
)
img_grad_x
=
np
.
repeat
(
np
.
reshape
(
img_grad
,
(
1
,
3
,
1
,
1
)),
self
.
_bottleneck
,
axis
=
2
)
*
np
.
eye
(
self
.
_bottleneck
,
dtype
=
'float32'
)
self
.
img_grad_x
=
self
.
add_weight
(
shape
=
img_grad_x
.
shape
,
initializer
=
tf
.
constant_initializer
(
img_grad_x
),
trainable
=
self
.
_train_feature_grad
,
name
=
'img_grad_x'
)
img_grad_y
=
np
.
repeat
(
np
.
reshape
(
img_grad
,
(
3
,
1
,
1
,
1
)),
self
.
_bottleneck
,
axis
=
2
)
*
np
.
eye
(
self
.
_bottleneck
,
dtype
=
'float32'
)
self
.
img_grad_y
=
self
.
add_weight
(
shape
=
img_grad_y
.
shape
,
initializer
=
tf
.
constant_initializer
(
img_grad_y
),
trainable
=
self
.
_train_feature_grad
,
name
=
'img_grad_y'
)
f_grad
=
np
.
array
([
-
1
,
1
],
dtype
=
'float32'
)
f_grad_x
=
np
.
repeat
(
np
.
reshape
(
f_grad
,
(
1
,
2
,
1
,
1
)),
self
.
_bottleneck
,
axis
=
2
)
*
np
.
eye
(
self
.
_bottleneck
,
dtype
=
'float32'
)
self
.
f_grad_x
=
self
.
add_weight
(
shape
=
f_grad_x
.
shape
,
initializer
=
tf
.
constant_initializer
(
f_grad_x
),
trainable
=
self
.
_train_divergence
,
name
=
'f_grad_x'
)
f_grad_y
=
np
.
repeat
(
np
.
reshape
(
f_grad
,
(
2
,
1
,
1
,
1
)),
self
.
_bottleneck
,
axis
=
2
)
*
np
.
eye
(
self
.
_bottleneck
,
dtype
=
'float32'
)
self
.
f_grad_y
=
self
.
add_weight
(
shape
=
f_grad_y
.
shape
,
initializer
=
tf
.
constant_initializer
(
f_grad_y
),
trainable
=
self
.
_train_divergence
,
name
=
'f_grad_y'
)
f_grad_x2
=
np
.
repeat
(
np
.
reshape
(
f_grad
,
(
1
,
2
,
1
,
1
)),
self
.
_bottleneck
,
axis
=
2
)
*
np
.
eye
(
self
.
_bottleneck
,
dtype
=
'float32'
)
self
.
f_grad_x2
=
self
.
add_weight
(
shape
=
f_grad_x2
.
shape
,
initializer
=
tf
.
constant_initializer
(
f_grad_x2
),
trainable
=
self
.
_train_flow_grad
,
name
=
'f_grad_x2'
)
f_grad_y2
=
np
.
repeat
(
np
.
reshape
(
f_grad
,
(
2
,
1
,
1
,
1
)),
self
.
_bottleneck
,
axis
=
2
)
*
np
.
eye
(
self
.
_bottleneck
,
dtype
=
'float32'
)
self
.
f_grad_y2
=
self
.
add_weight
(
shape
=
f_grad_y2
.
shape
,
initializer
=
tf
.
constant_initializer
(
f_grad_y2
),
trainable
=
self
.
_train_flow_grad
,
name
=
'f_grad_y2'
)
self
.
t
=
self
.
add_weight
(
name
=
'theta'
,
initializer
=
tf
.
constant_initializer
(
0.3
),
trainable
=
self
.
_train_hyper
)
self
.
l
=
self
.
add_weight
(
name
=
'lambda'
,
initializer
=
tf
.
constant_initializer
(
0.15
),
trainable
=
self
.
_train_hyper
)
self
.
a
=
self
.
add_weight
(
name
=
'tau'
,
initializer
=
tf
.
constant_initializer
(
0.25
),
trainable
=
self
.
_train_hyper
)
self
.
t
=
tf
.
abs
(
self
.
t
)
+
1e-12
self
.
l_t
=
self
.
l
*
self
.
t
self
.
taut
=
self
.
a
/
self
.
t
    self._bottleneck_conv1 = None
    self._bottleneck_conv2 = None
    if self._bottleneck > 1:
      self._bottleneck_conv1 = layers.Conv2D(
          filters=self._bottleneck,
          kernel_size=1,
          strides=1,
          padding='same',
          use_bias=False,
          kernel_initializer=tf.keras.initializers.VarianceScaling(),
          name='rf/bottleneck1')
      self._bottleneck_conv2 = layers.Conv2D(
          filters=self._depth,
          kernel_size=1,
          strides=1,
          padding='same',
          use_bias=False,
          kernel_initializer=tf.keras.initializers.VarianceScaling(),
          name='rf/bottleneck2')

    self._batch_norm = build_batch_norm(init_zero=True)
  def call(self, inputs: tf.Tensor, training: bool = None) -> tf.Tensor:
    """Performs representation flow.

    Args:
      inputs: A `Tensor` of shape `[batch*time, height, width, channels]`.
      training: True for the training phase.

    Returns:
      A tensor of the same shape as the inputs.
    """
    data_format = tf.keras.backend.image_data_format()
    df = 'NHWC' if data_format == 'channels_last' else 'NCHW'
    axis = 3 if data_format == 'channels_last' else 1  # channel axis

    dtype = inputs.dtype
    residual = inputs

    depth = inputs.shape.as_list()[axis]
    # assert depth == self._depth, f'rep_flow {depth} != {self._depth}'

    if self._bottleneck == 1:
      inputs = tf.reduce_mean(inputs, axis=axis)
      inputs = tf.expand_dims(inputs, -1)
    elif depth != self._bottleneck:
      inputs = self._bottleneck_conv1(inputs)

    input_shape = inputs.shape.as_list()
    inp = norm_img(inputs)
    inp = tf.reshape(
        inp,
        (-1, self._time, inputs.shape[1], inputs.shape[2], inputs.shape[3]))
    inp = tf.ensure_shape(
        inp,
        (None, self._time, input_shape[1], input_shape[2], input_shape[3]))
    img1 = tf.reshape(
        inp[:, :-1],
        (-1, tf.shape(inp)[2], tf.shape(inp)[3], tf.shape(inp)[4]))
    img2 = tf.reshape(
        inp[:, 1:],
        (-1, tf.shape(inp)[2], tf.shape(inp)[3], tf.shape(inp)[4]))
    img1 = tf.ensure_shape(
        img1, (None, inputs.shape[1], inputs.shape[2], inputs.shape[3]))
    img2 = tf.ensure_shape(
        img2, (None, inputs.shape[1], inputs.shape[2], inputs.shape[3]))

    u1 = tf.zeros_like(img1, dtype=dtype)
    u2 = tf.zeros_like(img2, dtype=dtype)
    l_t = self.l_t
    taut = self.taut

    grad2_x = tf.nn.conv2d(
        img2, self.img_grad_x, [1, 1, 1, 1], 'SAME', data_format=df)
    grad2_y = tf.nn.conv2d(
        img2, self.img_grad_y, [1, 1, 1, 1], 'SAME', data_format=df)

    p11 = tf.zeros_like(img1, dtype=dtype)
    p12 = tf.zeros_like(img1, dtype=dtype)
    p21 = tf.zeros_like(img1, dtype=dtype)
    p22 = tf.zeros_like(img1, dtype=dtype)

    gsqx = grad2_x**2
    gsqy = grad2_y**2

    grad = gsqx + gsqy + 1e-12

    rho_c = img2 - grad2_x * u1 - grad2_y * u2 - img1
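
    # Editorial note (interpretation, not from the original source): the loop
    # below runs `num_iter` primal-dual iterations of a TV-L1-style flow
    # solver on learned feature maps. `rho` is the linearized brightness-
    # constancy residual, the three masks implement the soft-thresholding step
    # on the data term, and (u1, u2) / (p11..p22) are the primal flow fields
    # and dual variables updated with the learned divergence and gradient
    # kernels.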
    for _ in range(self._num_iter):
      rho = rho_c + grad2_x * u1 + grad2_y * u2 + 1e-12

      v1 = tf.zeros_like(img1, dtype=dtype)
      v2 = tf.zeros_like(img2, dtype=dtype)

      mask1 = rho < -l_t * grad
      tmp11 = tf.where(mask1, l_t * grad2_x,
                       tf.zeros_like(grad2_x, dtype=dtype))
      tmp12 = tf.where(mask1, l_t * grad2_y,
                       tf.zeros_like(grad2_y, dtype=dtype))

      mask2 = rho > l_t * grad
      tmp21 = tf.where(mask2, -l_t * grad2_x,
                       tf.zeros_like(grad2_x, dtype=dtype))
      tmp22 = tf.where(mask2, -l_t * grad2_y,
                       tf.zeros_like(grad2_y, dtype=dtype))

      mask3 = (~mask1) & (~mask2) & (grad > 1e-12)
      tmp31 = tf.where(mask3, (-rho / grad) * grad2_x,
                       tf.zeros_like(grad2_x, dtype=dtype))
      tmp32 = tf.where(mask3, (-rho / grad) * grad2_y,
                       tf.zeros_like(grad2_y, dtype=dtype))

      v1 = tmp11 + tmp21 + tmp31 + u1
      v2 = tmp12 + tmp22 + tmp32 + u2

      u1 = v1 + self.t * divergence(
          p11, p12, self.f_grad_x, self.f_grad_y, 'div_p1')
      u2 = v2 + self.t * divergence(
          p21, p22, self.f_grad_x, self.f_grad_y, 'div_p2')

      u1x, u1y = forward_grad(u1, self.f_grad_x2, self.f_grad_y2, 'u1')
      u2x, u2y = forward_grad(u2, self.f_grad_x2, self.f_grad_y2, 'u2')

      p11 = (p11 + taut * u1x) / (
          1. + taut * tf.sqrt(u1x**2 + u1y**2 + 1e-12))
      p12 = (p12 + taut * u1y) / (
          1. + taut * tf.sqrt(u1x**2 + u1y**2 + 1e-12))
      p21 = (p21 + taut * u2x) / (
          1. + taut * tf.sqrt(u2x**2 + u2y**2 + 1e-12))
      p22 = (p22 + taut * u2y) / (
          1. + taut * tf.sqrt(u2x**2 + u2y**2 + 1e-12))
    u1 = tf.reshape(u1, (-1, self._time - 1, tf.shape(u1)[1],
                         tf.shape(u1)[2], tf.shape(u1)[3]))
    u2 = tf.reshape(u2, (-1, self._time - 1, tf.shape(u2)[1],
                         tf.shape(u2)[2], tf.shape(u2)[3]))
    flow = tf.concat([u1, u2], axis=axis + 1)
    flow = tf.concat([
        flow,
        tf.reshape(flow[:, -1, :, :, :],
                   (-1, 1, tf.shape(u1)[2], tf.shape(u1)[3],
                    tf.shape(u1)[4] * 2))
    ], axis=1)
    # Pad by repeating the last frame's flow; flow is now [bs, t, w, h, 2*c].
    flow = tf.reshape(flow, (-1, tf.shape(u1)[2], tf.shape(u2)[3],
                             tf.shape(u1)[4] * 2))
    # flow is [bs*t, w, h, 2*c]

    if self._bottleneck == 1:
      output_shape = residual.shape.as_list()
      output_shape[-1] = self._bottleneck * 2
      flow = tf.ensure_shape(flow, output_shape)
      return flow
    else:
      flow = self._bottleneck_conv2(flow)
      flow = self._batch_norm(flow)
      flow = tf.ensure_shape(flow, residual.shape)
      return tf.nn.relu(flow + residual)
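

# A minimal smoke-test sketch (editorial addition, not part of the original
# change). The constructor keyword names and defaults are assumed from the
# attribute assignments above, and shapes follow the `call` docstring,
# i.e. [batch * time, height, width, channels].
if __name__ == '__main__':
  # Two clips of four frames each, 16x16 feature maps with 32 channels.
  layer = RepresentationFlow(time=4, depth=32, num_iter=10, bottleneck=8)
  features = tf.random.normal([2 * 4, 16, 16, 32])
  out = layer(features)
  # The flow is projected back to the input depth and fused residually, so
  # the output shape matches the input shape.
  assert out.shape == features.shape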
official/vision/beta/projects/assemblenet/train.py
0 → 100644
View file @
c609ff2e
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Training driver."""
from absl import app
from absl import flags
from absl import logging
import gin

# pylint: disable=unused-import
from official.common import registry_imports
# pylint: enable=unused-import
from official.common import distribute_utils
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
# pylint: disable=unused-import
from official.vision.beta.projects.assemblenet.configs import assemblenet as asn_configs
from official.vision.beta.projects.assemblenet.modeling import assemblenet as asn
# pylint: enable=unused-import

FLAGS = flags.FLAGS
def main(_):
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  params = train_utils.parse_configuration(FLAGS)
  model_dir = FLAGS.model_dir
  if 'train' in FLAGS.mode:
    # Pure eval modes do not output yaml files. Otherwise a continuous eval job
    # may race against the train job for writing the same file.
    train_utils.serialize_config(params, model_dir)

  if 'train_and_eval' in FLAGS.mode:
    assert (params.task.train_data.feature_shape ==
            params.task.validation_data.feature_shape), (
                f'train {params.task.train_data.feature_shape} != validate '
                f'{params.task.validation_data.feature_shape}')

  if 'assemblenet' in FLAGS.experiment:
    if 'eval' in FLAGS.mode:
      # Use the feature shape in validation_data for all jobs. The number of
      # frames in train_data will be used to construct the Assemblenet model.
      params.task.model.backbone.assemblenet.num_frames = (
          params.task.validation_data.feature_shape[0])
      shape = params.task.validation_data.feature_shape
    else:
      params.task.model.backbone.assemblenet.num_frames = (
          params.task.train_data.feature_shape[0])
      shape = params.task.train_data.feature_shape
    logging.info('mode %r num_frames %r feature shape %r', FLAGS.mode,
                 params.task.model.backbone.assemblenet.num_frames, shape)

  # Sets the mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
  # can have a significant impact on model speed by utilizing float16 on GPUs
  # and bfloat16 on TPUs. loss_scale takes effect only when the dtype is
  # float16.
  if params.runtime.mixed_precision_dtype:
    performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
  distribution_strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
      num_gpus=params.runtime.num_gpus,
      tpu_address=params.runtime.tpu)
  with distribution_strategy.scope():
    task = task_factory.get_task(params.task, logging_dir=model_dir)

  train_lib.run_experiment(
      distribution_strategy=distribution_strategy,
      task=task,
      mode=FLAGS.mode,
      params=params,
      model_dir=model_dir)

  train_utils.save_gin_config(FLAGS.mode, model_dir)


if __name__ == '__main__':
  tfm_flags.define_flags()
  app.run(main)
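
# Example invocation (editorial sketch): the experiment name, config file, and
# model directory below are placeholders, not values defined in this change.
#
#   python -m official.vision.beta.projects.assemblenet.train \
#     --experiment=<assemblenet_experiment> \
#     --mode=train_and_eval \
#     --model_dir=/tmp/assemblenet \
#     --config_file=<path/to/params.yaml>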