Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
9725a407
Commit
9725a407
authored
Jun 29, 2020
by
A. Unique TensorFlower
Browse files
Introducing SpineNet backbone to TF2.
PiperOrigin-RevId: 318945656
parent
91e2171b
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
888 additions
and
2 deletions
+888
-2
official/README.md
official/README.md
+4
-2
official/benchmark/retinanet_benchmark.py
official/benchmark/retinanet_benchmark.py
+17
-0
official/vision/detection/README.md
official/vision/detection/README.md
+34
-0
official/vision/detection/configs/base_config.py
official/vision/detection/configs/base_config.py
+5
-0
official/vision/detection/modeling/architecture/factory.py
official/vision/detection/modeling/architecture/factory.py
+4
-0
official/vision/detection/modeling/architecture/nn_blocks.py
official/vision/detection/modeling/architecture/nn_blocks.py
+318
-0
official/vision/detection/modeling/architecture/spinenet.py
official/vision/detection/modeling/architecture/spinenet.py
+506
-0
No files found.
official/README.md
View file @
9725a407
...
@@ -19,9 +19,10 @@ In the near future, we will add:
...
@@ -19,9 +19,10 @@ In the near future, we will add:
*
State-of-the-art language understanding models:
*
State-of-the-art language understanding models:
More members in Transformer family
More members in Transformer family
*
Sta
r
t-of-the-art image classification models:
*
Stat
e
-of-the-art image classification models:
EfficientNet, MnasNet, and variants
EfficientNet, MnasNet, and variants
*
A set of excellent objection detection models.
*
State-of-the-art objection detection and instance segmentation models:
RetinaNet, Mask R-CNN, SpineNet, and variants
## Table of Contents
## Table of Contents
...
@@ -52,6 +53,7 @@ In the near future, we will add:
...
@@ -52,6 +53,7 @@ In the near future, we will add:
|
[
RetinaNet
](
vision/detection
)
|
[
Focal Loss for Dense Object Detection
](
https://arxiv.org/abs/1708.02002
)
|
|
[
RetinaNet
](
vision/detection
)
|
[
Focal Loss for Dense Object Detection
](
https://arxiv.org/abs/1708.02002
)
|
|
[
Mask R-CNN
](
vision/detection
)
|
[
Mask R-CNN
](
https://arxiv.org/abs/1703.06870
)
|
|
[
Mask R-CNN
](
vision/detection
)
|
[
Mask R-CNN
](
https://arxiv.org/abs/1703.06870
)
|
|
[
ShapeMask
](
vision/detection
)
|
[
ShapeMask: Learning to Segment Novel Objects by Refining Shape Priors
](
https://arxiv.org/abs/1904.03239
)
|
|
[
ShapeMask
](
vision/detection
)
|
[
ShapeMask: Learning to Segment Novel Objects by Refining Shape Priors
](
https://arxiv.org/abs/1904.03239
)
|
|
[
SpineNet
](
vision/detection
)
|
[
SpineNet: Learning Scale-Permuted Backbone for Recognition and Localization
](
https://arxiv.org/abs/1912.05027
)
|
### Natural Language Processing
### Natural Language Processing
...
...
official/benchmark/retinanet_benchmark.py
View file @
9725a407
...
@@ -271,6 +271,23 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
...
@@ -271,6 +271,23 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
FLAGS
.
strategy_type
=
'tpu'
FLAGS
.
strategy_type
=
'tpu'
self
.
_run_and_report_benchmark
(
params
,
do_eval
=
False
,
warmup
=
0
)
self
.
_run_and_report_benchmark
(
params
,
do_eval
=
False
,
warmup
=
0
)
@
flagsaver
.
flagsaver
def
benchmark_2x2_tpu_spinenet_coco
(
self
):
"""Run SpineNet with RetinaNet model accuracy test with 4 TPUs."""
self
.
_setup
()
params
=
self
.
_params
()
params
[
'architecture'
][
'backbone'
]
=
'spinenet'
params
[
'architecture'
][
'multilevel_features'
]
=
'identity'
params
[
'architecture'
][
'use_bfloat16'
]
=
False
params
[
'train'
][
'batch_size'
]
=
64
params
[
'train'
][
'total_steps'
]
=
1875
# One epoch.
params
[
'train'
][
'iterations_per_loop'
]
=
500
params
[
'train'
][
'checkpoint'
][
'path'
]
=
''
FLAGS
.
model_dir
=
self
.
_get_model_dir
(
'real_benchmark_2x2_tpu_spinenet_coco'
)
FLAGS
.
strategy_type
=
'tpu'
self
.
_run_and_report_benchmark
(
params
,
do_eval
=
False
,
warmup
=
0
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
tf
.
test
.
main
()
official/vision/detection/README.md
View file @
9725a407
...
@@ -48,6 +48,22 @@ so the checkpoints are not compatible.
...
@@ -48,6 +48,22 @@ so the checkpoints are not compatible.
We will unify the implementation soon.
We will unify the implementation soon.
### Train a SpineNet-49 based RetinaNet.
```
bash
TPU_NAME
=
"<your GCP TPU name>"
MODEL_DIR
=
"<path to the directory to store model files>"
TRAIN_FILE_PATTERN
=
"<path to the TFRecord training data>"
EVAL_FILE_PATTERN
=
"<path to the TFRecord validation data>"
VAL_JSON_FILE
=
"<path to the validation annotation JSON file>"
python3 ~/models/official/vision/detection/main.py
\
--strategy_type
=
tpu
\
--tpu
=
"
${
TPU_NAME
?
}
"
\
--model_dir
=
"
${
MODEL_DIR
?
}
"
\
--mode
=
train
\
--params_override
=
"{ type: retinanet, architecture: {backbone: spinenet, multilevel_features: identity}, spinenet: {model_id: 49}, train_file_pattern:
${
TRAIN_FILE_PATTERN
?
}
}, eval: { val_json_file:
${
VAL_JSON_FILE
?
}
, eval_file_pattern:
${
EVAL_FILE_PATTERN
?
}
} }"
```
### Train a custom RetinaNet using the config file.
### Train a custom RetinaNet using the config file.
...
@@ -163,6 +179,24 @@ so the checkpoints are not compatible.
...
@@ -163,6 +179,24 @@ so the checkpoints are not compatible.
We will unify the implementation soon.
We will unify the implementation soon.
### Train a SpineNet-49 based Mask R-CNN.
```
bash
TPU_NAME
=
"<your GCP TPU name>"
MODEL_DIR
=
"<path to the directory to store model files>"
TRAIN_FILE_PATTERN
=
"<path to the TFRecord training data>"
EVAL_FILE_PATTERN
=
"<path to the TFRecord validation data>"
VAL_JSON_FILE
=
"<path to the validation annotation JSON file>"
python3 ~/models/official/vision/detection/main.py
\
--strategy_type
=
tpu
\
--tpu
=
"
${
TPU_NAME
?
}
"
\
--model_dir
=
"
${
MODEL_DIR
?
}
"
\
--mode
=
train
\
--model
=
mask_rcnn
\
--params_override
=
"{architecture: {backbone: spinenet, multilevel_features: identity}, spinenet: {model_id: 49}, train_file_pattern:
${
TRAIN_FILE_PATTERN
?
}
}, eval: { val_json_file:
${
VAL_JSON_FILE
?
}
, eval_file_pattern:
${
EVAL_FILE_PATTERN
?
}
} }"
```
### Train a custom Mask R-CNN using the config file.
### Train a custom Mask R-CNN using the config file.
First, create a YAML config file, e.g.
*my_maskrcnn.yaml*
.
First, create a YAML config file, e.g.
*my_maskrcnn.yaml*
.
...
...
official/vision/detection/configs/base_config.py
View file @
9725a407
...
@@ -17,10 +17,12 @@
...
@@ -17,10 +17,12 @@
BACKBONES
=
[
BACKBONES
=
[
'resnet'
,
'resnet'
,
'spinenet'
,
]
]
MULTILEVEL_FEATURES
=
[
MULTILEVEL_FEATURES
=
[
'fpn'
,
'fpn'
,
'identity'
,
]
]
# pylint: disable=line-too-long
# pylint: disable=line-too-long
...
@@ -118,6 +120,9 @@ BASE_CFG = {
...
@@ -118,6 +120,9 @@ BASE_CFG = {
'resnet'
:
{
'resnet'
:
{
'resnet_depth'
:
50
,
'resnet_depth'
:
50
,
},
},
'spinenet'
:
{
'model_id'
:
'49'
,
},
'fpn'
:
{
'fpn'
:
{
'fpn_feat_dims'
:
256
,
'fpn_feat_dims'
:
256
,
'use_separable_conv'
:
False
,
'use_separable_conv'
:
False
,
...
...
official/vision/detection/modeling/architecture/factory.py
View file @
9725a407
...
@@ -23,6 +23,7 @@ from official.vision.detection.modeling.architecture import heads
...
@@ -23,6 +23,7 @@ from official.vision.detection.modeling.architecture import heads
from
official.vision.detection.modeling.architecture
import
identity
from
official.vision.detection.modeling.architecture
import
identity
from
official.vision.detection.modeling.architecture
import
nn_ops
from
official.vision.detection.modeling.architecture
import
nn_ops
from
official.vision.detection.modeling.architecture
import
resnet
from
official.vision.detection.modeling.architecture
import
resnet
from
official.vision.detection.modeling.architecture
import
spinenet
def
norm_activation_generator
(
params
):
def
norm_activation_generator
(
params
):
...
@@ -42,6 +43,9 @@ def backbone_generator(params):
...
@@ -42,6 +43,9 @@ def backbone_generator(params):
activation
=
params
.
norm_activation
.
activation
,
activation
=
params
.
norm_activation
.
activation
,
norm_activation
=
norm_activation_generator
(
norm_activation
=
norm_activation_generator
(
params
.
norm_activation
))
params
.
norm_activation
))
elif
params
.
architecture
.
backbone
==
'spinenet'
:
spinenet_params
=
params
.
spinenet
backbone_fn
=
spinenet
.
SpineNetBuilder
(
model_id
=
spinenet_params
.
model_id
)
else
:
else
:
raise
ValueError
(
'Backbone model `{}` is not supported.'
raise
ValueError
(
'Backbone model `{}` is not supported.'
.
format
(
params
.
architecture
.
backbone
))
.
format
(
params
.
architecture
.
backbone
))
...
...
official/vision/detection/modeling/architecture/nn_blocks.py
0 → 100644
View file @
9725a407
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains common building blocks for neural networks."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
tensorflow
as
tf
from
official.modeling
import
tf_utils
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Vision'
)
class
ResidualBlock
(
tf
.
keras
.
layers
.
Layer
):
"""A residual block."""
def
__init__
(
self
,
filters
,
strides
,
use_projection
=
False
,
kernel_initializer
=
'VarianceScaling'
,
kernel_regularizer
=
None
,
bias_regularizer
=
None
,
activation
=
'relu'
,
use_sync_bn
=
False
,
norm_momentum
=
0.99
,
norm_epsilon
=
0.001
,
**
kwargs
):
"""A residual block with BN after convolutions.
Args:
filters: `int` number of filters for the first two convolutions. Note that
the third and final convolution will use 4 times as many filters.
strides: `int` block stride. If greater than 1, this block will ultimately
downsample the input.
use_projection: `bool` for whether this block should use a projection
shortcut (versus the default identity shortcut). This is usually `True`
for the first block of a block group, which may change the number of
filters and the resolution.
kernel_initializer: kernel_initializer for convolutional layers.
kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
Default to None.
bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2d.
Default to None.
activation: `str` name of the activation function.
use_sync_bn: if True, use synchronized batch normalization.
norm_momentum: `float` normalization omentum for the moving average.
norm_epsilon: `float` small float added to variance to avoid dividing by
zero.
**kwargs: keyword arguments to be passed.
"""
super
(
ResidualBlock
,
self
).
__init__
(
**
kwargs
)
self
.
_filters
=
filters
self
.
_strides
=
strides
self
.
_use_projection
=
use_projection
self
.
_use_sync_bn
=
use_sync_bn
self
.
_activation
=
activation
self
.
_kernel_initializer
=
kernel_initializer
self
.
_norm_momentum
=
norm_momentum
self
.
_norm_epsilon
=
norm_epsilon
self
.
_kernel_regularizer
=
kernel_regularizer
self
.
_bias_regularizer
=
bias_regularizer
if
use_sync_bn
:
self
.
_norm
=
tf
.
keras
.
layers
.
experimental
.
SyncBatchNormalization
else
:
self
.
_norm
=
tf
.
keras
.
layers
.
BatchNormalization
if
tf
.
keras
.
backend
.
image_data_format
()
==
'channels_last'
:
self
.
_bn_axis
=
-
1
else
:
self
.
_bn_axis
=
1
self
.
_activation_fn
=
tf_utils
.
get_activation
(
activation
)
def
build
(
self
,
input_shape
):
if
self
.
_use_projection
:
self
.
_shortcut
=
tf
.
keras
.
layers
.
Conv2D
(
filters
=
self
.
_filters
,
kernel_size
=
1
,
strides
=
self
.
_strides
,
use_bias
=
False
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
)
self
.
_norm0
=
self
.
_norm
(
axis
=
self
.
_bn_axis
,
momentum
=
self
.
_norm_momentum
,
epsilon
=
self
.
_norm_epsilon
)
self
.
_conv1
=
tf
.
keras
.
layers
.
Conv2D
(
filters
=
self
.
_filters
,
kernel_size
=
3
,
strides
=
self
.
_strides
,
padding
=
'same'
,
use_bias
=
False
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
)
self
.
_norm1
=
self
.
_norm
(
axis
=
self
.
_bn_axis
,
momentum
=
self
.
_norm_momentum
,
epsilon
=
self
.
_norm_epsilon
)
self
.
_conv2
=
tf
.
keras
.
layers
.
Conv2D
(
filters
=
self
.
_filters
,
kernel_size
=
3
,
strides
=
1
,
padding
=
'same'
,
use_bias
=
False
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
)
self
.
_norm2
=
self
.
_norm
(
axis
=
self
.
_bn_axis
,
momentum
=
self
.
_norm_momentum
,
epsilon
=
self
.
_norm_epsilon
)
super
(
ResidualBlock
,
self
).
build
(
input_shape
)
def
get_config
(
self
):
config
=
{
'filters'
:
self
.
_filters
,
'strides'
:
self
.
_strides
,
'use_projection'
:
self
.
_use_projection
,
'kernel_initializer'
:
self
.
_kernel_initializer
,
'kernel_regularizer'
:
self
.
_kernel_regularizer
,
'bias_regularizer'
:
self
.
_bias_regularizer
,
'activation'
:
self
.
_activation
,
'use_sync_bn'
:
self
.
_use_sync_bn
,
'norm_momentum'
:
self
.
_norm_momentum
,
'norm_epsilon'
:
self
.
_norm_epsilon
}
base_config
=
super
(
ResidualBlock
,
self
).
get_config
()
return
dict
(
list
(
base_config
.
items
())
+
list
(
config
.
items
()))
def
call
(
self
,
inputs
):
shortcut
=
inputs
if
self
.
_use_projection
:
shortcut
=
self
.
_shortcut
(
shortcut
)
shortcut
=
self
.
_norm0
(
shortcut
)
x
=
self
.
_conv1
(
inputs
)
x
=
self
.
_norm1
(
x
)
x
=
self
.
_activation_fn
(
x
)
x
=
self
.
_conv2
(
x
)
x
=
self
.
_norm2
(
x
)
return
self
.
_activation_fn
(
x
+
shortcut
)
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Vision'
)
class
BottleneckBlock
(
tf
.
keras
.
layers
.
Layer
):
"""A standard bottleneck block."""
def
__init__
(
self
,
filters
,
strides
,
use_projection
=
False
,
kernel_initializer
=
'VarianceScaling'
,
kernel_regularizer
=
None
,
bias_regularizer
=
None
,
activation
=
'relu'
,
use_sync_bn
=
False
,
norm_momentum
=
0.99
,
norm_epsilon
=
0.001
,
**
kwargs
):
"""A standard bottleneck block with BN after convolutions.
Args:
filters: `int` number of filters for the first two convolutions. Note that
the third and final convolution will use 4 times as many filters.
strides: `int` block stride. If greater than 1, this block will ultimately
downsample the input.
use_projection: `bool` for whether this block should use a projection
shortcut (versus the default identity shortcut). This is usually `True`
for the first block of a block group, which may change the number of
filters and the resolution.
kernel_initializer: kernel_initializer for convolutional layers.
kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
Default to None.
bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2d.
Default to None.
activation: `str` name of the activation function.
use_sync_bn: if True, use synchronized batch normalization.
norm_momentum: `float` normalization omentum for the moving average.
norm_epsilon: `float` small float added to variance to avoid dividing by
zero.
**kwargs: keyword arguments to be passed.
"""
super
(
BottleneckBlock
,
self
).
__init__
(
**
kwargs
)
self
.
_filters
=
filters
self
.
_strides
=
strides
self
.
_use_projection
=
use_projection
self
.
_use_sync_bn
=
use_sync_bn
self
.
_activation
=
activation
self
.
_kernel_initializer
=
kernel_initializer
self
.
_norm_momentum
=
norm_momentum
self
.
_norm_epsilon
=
norm_epsilon
self
.
_kernel_regularizer
=
kernel_regularizer
self
.
_bias_regularizer
=
bias_regularizer
if
use_sync_bn
:
self
.
_norm
=
tf
.
keras
.
layers
.
experimental
.
SyncBatchNormalization
else
:
self
.
_norm
=
tf
.
keras
.
layers
.
BatchNormalization
if
tf
.
keras
.
backend
.
image_data_format
()
==
'channels_last'
:
self
.
_bn_axis
=
-
1
else
:
self
.
_bn_axis
=
1
self
.
_activation_fn
=
tf_utils
.
get_activation
(
activation
)
def
build
(
self
,
input_shape
):
if
self
.
_use_projection
:
self
.
_shortcut
=
tf
.
keras
.
layers
.
Conv2D
(
filters
=
self
.
_filters
*
4
,
kernel_size
=
1
,
strides
=
self
.
_strides
,
use_bias
=
False
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
)
self
.
_norm0
=
self
.
_norm
(
axis
=
self
.
_bn_axis
,
momentum
=
self
.
_norm_momentum
,
epsilon
=
self
.
_norm_epsilon
)
self
.
_conv1
=
tf
.
keras
.
layers
.
Conv2D
(
filters
=
self
.
_filters
,
kernel_size
=
1
,
strides
=
1
,
use_bias
=
False
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
)
self
.
_norm1
=
self
.
_norm
(
axis
=
self
.
_bn_axis
,
momentum
=
self
.
_norm_momentum
,
epsilon
=
self
.
_norm_epsilon
)
self
.
_conv2
=
tf
.
keras
.
layers
.
Conv2D
(
filters
=
self
.
_filters
,
kernel_size
=
3
,
strides
=
self
.
_strides
,
padding
=
'same'
,
use_bias
=
False
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
)
self
.
_norm2
=
self
.
_norm
(
axis
=
self
.
_bn_axis
,
momentum
=
self
.
_norm_momentum
,
epsilon
=
self
.
_norm_epsilon
)
self
.
_conv3
=
tf
.
keras
.
layers
.
Conv2D
(
filters
=
self
.
_filters
*
4
,
kernel_size
=
1
,
strides
=
1
,
use_bias
=
False
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
)
self
.
_norm3
=
self
.
_norm
(
axis
=
self
.
_bn_axis
,
momentum
=
self
.
_norm_momentum
,
epsilon
=
self
.
_norm_epsilon
)
super
(
BottleneckBlock
,
self
).
build
(
input_shape
)
def
get_config
(
self
):
config
=
{
'filters'
:
self
.
_filters
,
'strides'
:
self
.
_strides
,
'use_projection'
:
self
.
_use_projection
,
'kernel_initializer'
:
self
.
_kernel_initializer
,
'kernel_regularizer'
:
self
.
_kernel_regularizer
,
'bias_regularizer'
:
self
.
_bias_regularizer
,
'activation'
:
self
.
_activation
,
'use_sync_bn'
:
self
.
_use_sync_bn
,
'norm_momentum'
:
self
.
_norm_momentum
,
'norm_epsilon'
:
self
.
_norm_epsilon
}
base_config
=
super
(
BottleneckBlock
,
self
).
get_config
()
return
dict
(
list
(
base_config
.
items
())
+
list
(
config
.
items
()))
def
call
(
self
,
inputs
):
shortcut
=
inputs
if
self
.
_use_projection
:
shortcut
=
self
.
_shortcut
(
shortcut
)
shortcut
=
self
.
_norm0
(
shortcut
)
x
=
self
.
_conv1
(
inputs
)
x
=
self
.
_norm1
(
x
)
x
=
self
.
_activation_fn
(
x
)
x
=
self
.
_conv2
(
x
)
x
=
self
.
_norm2
(
x
)
x
=
self
.
_activation_fn
(
x
)
x
=
self
.
_conv3
(
x
)
x
=
self
.
_norm3
(
x
)
return
self
.
_activation_fn
(
x
+
shortcut
)
official/vision/detection/modeling/architecture/spinenet.py
0 → 100644
View file @
9725a407
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of SpineNet model.
X. Du, T-Y. Lin, P. Jin, G. Ghiasi, M. Tan, Y. Cui, Q. V. Le, X. Song
SpineNet: Learning Scale-Permuted Backbone for Recognition and Localization
https://arxiv.org/abs/1912.05027
"""
import
math
from
absl
import
logging
import
tensorflow
as
tf
from
tensorflow.python.keras
import
backend
from
official.modeling
import
tf_utils
from
official.vision.detection.modeling.architecture
import
nn_blocks
layers
=
tf
.
keras
.
layers
FILTER_SIZE_MAP
=
{
1
:
32
,
2
:
64
,
3
:
128
,
4
:
256
,
5
:
256
,
6
:
256
,
7
:
256
,
}
# The fixed SpineNet architecture discovered by NAS.
# Each element represents a specification of a building block:
# (block_level, block_fn, (input_offset0, input_offset1), is_output).
SPINENET_BLOCK_SPECS
=
[
(
2
,
'bottleneck'
,
(
0
,
1
),
False
),
(
4
,
'residual'
,
(
0
,
1
),
False
),
(
3
,
'bottleneck'
,
(
2
,
3
),
False
),
(
4
,
'bottleneck'
,
(
2
,
4
),
False
),
(
6
,
'residual'
,
(
3
,
5
),
False
),
(
4
,
'bottleneck'
,
(
3
,
5
),
False
),
(
5
,
'residual'
,
(
6
,
7
),
False
),
(
7
,
'residual'
,
(
6
,
8
),
False
),
(
5
,
'bottleneck'
,
(
8
,
9
),
False
),
(
5
,
'bottleneck'
,
(
8
,
10
),
False
),
(
4
,
'bottleneck'
,
(
5
,
10
),
True
),
(
3
,
'bottleneck'
,
(
4
,
10
),
True
),
(
5
,
'bottleneck'
,
(
7
,
12
),
True
),
(
7
,
'bottleneck'
,
(
5
,
14
),
True
),
(
6
,
'bottleneck'
,
(
12
,
14
),
True
),
]
SCALING_MAP
=
{
'49S'
:
{
'endpoints_num_filters'
:
128
,
'filter_size_scale'
:
0.65
,
'resample_alpha'
:
0.5
,
'block_repeats'
:
1
,
},
'49'
:
{
'endpoints_num_filters'
:
256
,
'filter_size_scale'
:
1.0
,
'resample_alpha'
:
0.5
,
'block_repeats'
:
1
,
},
'96'
:
{
'endpoints_num_filters'
:
256
,
'filter_size_scale'
:
1.0
,
'resample_alpha'
:
0.5
,
'block_repeats'
:
2
,
},
'143'
:
{
'endpoints_num_filters'
:
256
,
'filter_size_scale'
:
1.0
,
'resample_alpha'
:
1.0
,
'block_repeats'
:
3
,
},
'190'
:
{
'endpoints_num_filters'
:
512
,
'filter_size_scale'
:
1.3
,
'resample_alpha'
:
1.0
,
'block_repeats'
:
4
,
},
}
class
BlockSpec
(
object
):
"""A container class that specifies the block configuration for SpineNet."""
def
__init__
(
self
,
level
,
block_fn
,
input_offsets
,
is_output
):
self
.
level
=
level
self
.
block_fn
=
block_fn
self
.
input_offsets
=
input_offsets
self
.
is_output
=
is_output
def
build_block_specs
(
block_specs
=
None
):
"""Builds the list of BlockSpec objects for SpineNet."""
if
not
block_specs
:
block_specs
=
SPINENET_BLOCK_SPECS
logging
.
info
(
'Building SpineNet block specs: %s'
,
block_specs
)
return
[
BlockSpec
(
*
b
)
for
b
in
block_specs
]
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Vision'
)
class
SpineNet
(
tf
.
keras
.
Model
):
"""Class to build SpineNet models."""
def
__init__
(
self
,
input_specs
=
tf
.
keras
.
layers
.
InputSpec
(
shape
=
[
None
,
640
,
640
,
3
]),
min_level
=
3
,
max_level
=
7
,
block_specs
=
build_block_specs
(),
endpoints_num_filters
=
256
,
resample_alpha
=
0.5
,
block_repeats
=
1
,
filter_size_scale
=
1.0
,
kernel_initializer
=
'VarianceScaling'
,
kernel_regularizer
=
None
,
bias_regularizer
=
None
,
activation
=
'relu'
,
use_sync_bn
=
False
,
norm_momentum
=
0.99
,
norm_epsilon
=
0.001
,
**
kwargs
):
"""SpineNet model."""
self
.
_min_level
=
min_level
self
.
_max_level
=
max_level
self
.
_block_specs
=
block_specs
self
.
_endpoints_num_filters
=
endpoints_num_filters
self
.
_resample_alpha
=
resample_alpha
self
.
_block_repeats
=
block_repeats
self
.
_filter_size_scale
=
filter_size_scale
self
.
_kernel_initializer
=
kernel_initializer
self
.
_kernel_regularizer
=
kernel_regularizer
self
.
_bias_regularizer
=
bias_regularizer
self
.
_use_sync_bn
=
use_sync_bn
self
.
_norm_momentum
=
norm_momentum
self
.
_norm_epsilon
=
norm_epsilon
if
activation
==
'relu'
:
self
.
_activation
=
tf
.
nn
.
relu
elif
activation
==
'swish'
:
self
.
_activation
=
tf
.
nn
.
swish
else
:
raise
ValueError
(
'Activation {} not implemented.'
.
format
(
activation
))
self
.
_init_block_fn
=
'bottleneck'
self
.
_num_init_blocks
=
2
if
use_sync_bn
:
self
.
_norm
=
layers
.
experimental
.
SyncBatchNormalization
else
:
self
.
_norm
=
layers
.
BatchNormalization
if
tf
.
keras
.
backend
.
image_data_format
()
==
'channels_last'
:
self
.
_bn_axis
=
-
1
else
:
self
.
_bn_axis
=
1
# Build SpineNet.
inputs
=
tf
.
keras
.
Input
(
shape
=
input_specs
.
shape
[
1
:])
net
=
self
.
_build_stem
(
inputs
=
inputs
)
net
=
self
.
_build_scale_permuted_network
(
net
=
net
,
input_width
=
input_specs
.
shape
[
1
])
net
=
self
.
_build_endpoints
(
net
=
net
)
super
(
SpineNet
,
self
).
__init__
(
inputs
=
inputs
,
outputs
=
net
)
def
_block_group
(
self
,
inputs
,
filters
,
strides
,
block_fn_cand
,
block_repeats
=
1
,
name
=
'block_group'
):
"""Creates one group of blocks for the SpineNet model."""
block_fn_candidates
=
{
'bottleneck'
:
nn_blocks
.
BottleneckBlock
,
'residual'
:
nn_blocks
.
ResidualBlock
,
}
block_fn
=
block_fn_candidates
[
block_fn_cand
]
_
,
_
,
_
,
num_filters
=
inputs
.
get_shape
().
as_list
()
if
block_fn_cand
==
'bottleneck'
:
use_projection
=
not
(
num_filters
==
(
filters
*
4
)
and
strides
==
1
)
else
:
use_projection
=
not
(
num_filters
==
filters
and
strides
==
1
)
x
=
block_fn
(
filters
=
filters
,
strides
=
strides
,
use_projection
=
use_projection
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
,
activation
=
self
.
_activation
,
use_sync_bn
=
self
.
_use_sync_bn
,
norm_momentum
=
self
.
_norm_momentum
,
norm_epsilon
=
self
.
_norm_epsilon
)(
inputs
)
for
_
in
range
(
1
,
block_repeats
):
x
=
block_fn
(
filters
=
filters
,
strides
=
1
,
use_projection
=
False
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
,
activation
=
self
.
_activation
,
use_sync_bn
=
self
.
_use_sync_bn
,
norm_momentum
=
self
.
_norm_momentum
,
norm_epsilon
=
self
.
_norm_epsilon
)(
x
)
return
tf
.
identity
(
x
,
name
=
name
)
def
_build_stem
(
self
,
inputs
):
"""Build SpineNet stem."""
x
=
layers
.
Conv2D
(
filters
=
64
,
kernel_size
=
7
,
strides
=
2
,
use_bias
=
False
,
padding
=
'same'
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
)(
inputs
)
x
=
self
.
_norm
(
axis
=
self
.
_bn_axis
,
momentum
=
self
.
_norm_momentum
,
epsilon
=
self
.
_norm_epsilon
)(
x
)
x
=
tf_utils
.
get_activation
(
self
.
_activation
)(
x
)
x
=
layers
.
MaxPool2D
(
pool_size
=
3
,
strides
=
2
,
padding
=
'same'
)(
x
)
net
=
[]
# Build the initial level 2 blocks.
for
i
in
range
(
self
.
_num_init_blocks
):
x
=
self
.
_block_group
(
inputs
=
x
,
filters
=
int
(
FILTER_SIZE_MAP
[
2
]
*
self
.
_filter_size_scale
),
strides
=
1
,
block_fn_cand
=
self
.
_init_block_fn
,
block_repeats
=
self
.
_block_repeats
,
name
=
'stem_block_{}'
.
format
(
i
+
1
))
net
.
append
(
x
)
return
net
def
_build_scale_permuted_network
(
self
,
net
,
input_width
,
weighted_fusion
=
False
):
"""Build scale-permuted network."""
net_sizes
=
[
int
(
math
.
ceil
(
input_width
/
2
**
2
))]
*
len
(
net
)
net_block_fns
=
[
self
.
_init_block_fn
]
*
len
(
net
)
num_outgoing_connections
=
[
0
]
*
len
(
net
)
endpoints
=
{}
for
i
,
block_spec
in
enumerate
(
self
.
_block_specs
):
# Find out specs for the target block.
target_width
=
int
(
math
.
ceil
(
input_width
/
2
**
block_spec
.
level
))
target_num_filters
=
int
(
FILTER_SIZE_MAP
[
block_spec
.
level
]
*
self
.
_filter_size_scale
)
target_block_fn
=
block_spec
.
block_fn
# Resample then merge input0 and input1.
parents
=
[]
input0
=
block_spec
.
input_offsets
[
0
]
input1
=
block_spec
.
input_offsets
[
1
]
x0
=
self
.
_resample_with_alpha
(
inputs
=
net
[
input0
],
input_width
=
net_sizes
[
input0
],
input_block_fn
=
net_block_fns
[
input0
],
target_width
=
target_width
,
target_num_filters
=
target_num_filters
,
target_block_fn
=
target_block_fn
,
alpha
=
self
.
_resample_alpha
)
parents
.
append
(
x0
)
num_outgoing_connections
[
input0
]
+=
1
x1
=
self
.
_resample_with_alpha
(
inputs
=
net
[
input1
],
input_width
=
net_sizes
[
input1
],
input_block_fn
=
net_block_fns
[
input1
],
target_width
=
target_width
,
target_num_filters
=
target_num_filters
,
target_block_fn
=
target_block_fn
,
alpha
=
self
.
_resample_alpha
)
parents
.
append
(
x1
)
num_outgoing_connections
[
input1
]
+=
1
# Merge 0 outdegree blocks to the output block.
if
block_spec
.
is_output
:
for
j
,
(
j_feat
,
j_connections
)
in
enumerate
(
zip
(
net
,
num_outgoing_connections
)):
if
j_connections
==
0
and
(
j_feat
.
shape
[
2
]
==
target_width
and
j_feat
.
shape
[
3
]
==
x0
.
shape
[
3
]):
parents
.
append
(
j_feat
)
num_outgoing_connections
[
j
]
+=
1
# pylint: disable=g-direct-tensorflow-import
if
weighted_fusion
:
dtype
=
parents
[
0
].
dtype
parent_weights
=
[
tf
.
nn
.
relu
(
tf
.
cast
(
tf
.
Variable
(
1.0
,
name
=
'block{}_fusion{}'
.
format
(
i
,
j
)),
dtype
=
dtype
))
for
j
in
range
(
len
(
parents
))]
weights_sum
=
tf
.
add_n
(
parent_weights
)
parents
=
[
parents
[
i
]
*
parent_weights
[
i
]
/
(
weights_sum
+
0.0001
)
for
i
in
range
(
len
(
parents
))
]
# Fuse all parent nodes then build a new block.
x
=
tf_utils
.
get_activation
(
self
.
_activation
)(
tf
.
add_n
(
parents
))
x
=
self
.
_block_group
(
inputs
=
x
,
filters
=
target_num_filters
,
strides
=
1
,
block_fn_cand
=
target_block_fn
,
block_repeats
=
self
.
_block_repeats
,
name
=
'scale_permuted_block_{}'
.
format
(
i
+
1
))
net
.
append
(
x
)
net_sizes
.
append
(
target_width
)
net_block_fns
.
append
(
target_block_fn
)
num_outgoing_connections
.
append
(
0
)
# Save output feats.
if
block_spec
.
is_output
:
if
block_spec
.
level
in
endpoints
:
raise
ValueError
(
'Duplicate feats found for output level {}.'
.
format
(
block_spec
.
level
))
if
(
block_spec
.
level
<
self
.
_min_level
or
block_spec
.
level
>
self
.
_max_level
):
raise
ValueError
(
'Output level is out of range [{}, {}]'
.
format
(
self
.
_min_level
,
self
.
_max_level
))
endpoints
[
block_spec
.
level
]
=
x
return
endpoints
def
_build_endpoints
(
self
,
net
):
"""Match filter size for endpoints before sharing conv layers."""
endpoints
=
{}
for
level
in
range
(
self
.
_min_level
,
self
.
_max_level
+
1
):
x
=
layers
.
Conv2D
(
filters
=
self
.
_endpoints_num_filters
,
kernel_size
=
1
,
strides
=
1
,
use_bias
=
False
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
)(
net
[
level
])
x
=
self
.
_norm
(
axis
=
self
.
_bn_axis
,
momentum
=
self
.
_norm_momentum
,
epsilon
=
self
.
_norm_epsilon
)(
x
)
x
=
tf_utils
.
get_activation
(
self
.
_activation
)(
x
)
endpoints
[
level
]
=
x
return
endpoints
def
_resample_with_alpha
(
self
,
inputs
,
input_width
,
input_block_fn
,
target_width
,
target_num_filters
,
target_block_fn
,
alpha
=
0.5
):
"""Match resolution and feature dimension."""
_
,
_
,
_
,
input_num_filters
=
inputs
.
get_shape
().
as_list
()
if
input_block_fn
==
'bottleneck'
:
input_num_filters
/=
4
new_num_filters
=
int
(
input_num_filters
*
alpha
)
x
=
layers
.
Conv2D
(
filters
=
new_num_filters
,
kernel_size
=
1
,
strides
=
1
,
use_bias
=
False
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
)(
inputs
)
x
=
self
.
_norm
(
axis
=
self
.
_bn_axis
,
momentum
=
self
.
_norm_momentum
,
epsilon
=
self
.
_norm_epsilon
)(
x
)
x
=
tf_utils
.
get_activation
(
self
.
_activation
)(
x
)
# Spatial resampling.
if
input_width
>
target_width
:
x
=
layers
.
Conv2D
(
filters
=
new_num_filters
,
kernel_size
=
3
,
strides
=
2
,
padding
=
'SAME'
,
use_bias
=
False
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
)(
x
)
x
=
self
.
_norm
(
axis
=
self
.
_bn_axis
,
momentum
=
self
.
_norm_momentum
,
epsilon
=
self
.
_norm_epsilon
)(
x
)
x
=
tf_utils
.
get_activation
(
self
.
_activation
)(
x
)
input_width
/=
2
while
input_width
>
target_width
:
x
=
layers
.
MaxPool2D
(
pool_size
=
3
,
strides
=
2
,
padding
=
'SAME'
)(
x
)
input_width
/=
2
elif
input_width
<
target_width
:
scale
=
target_width
//
input_width
x
=
layers
.
UpSampling2D
(
size
=
(
scale
,
scale
))(
x
)
# Last 1x1 conv to match filter size.
if
target_block_fn
==
'bottleneck'
:
target_num_filters
*=
4
x
=
layers
.
Conv2D
(
filters
=
target_num_filters
,
kernel_size
=
1
,
strides
=
1
,
use_bias
=
False
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
)(
x
)
x
=
self
.
_norm
(
axis
=
self
.
_bn_axis
,
momentum
=
self
.
_norm_momentum
,
epsilon
=
self
.
_norm_epsilon
)(
x
)
return
x
class
SpineNetBuilder
(
object
):
"""SpineNet builder."""
def
__init__
(
self
,
model_id
,
input_specs
=
tf
.
keras
.
layers
.
InputSpec
(
shape
=
[
None
,
640
,
640
,
3
]),
min_level
=
3
,
max_level
=
7
,
block_specs
=
build_block_specs
(),
kernel_initializer
=
'VarianceScaling'
,
kernel_regularizer
=
None
,
bias_regularizer
=
None
,
activation
=
'relu'
,
use_sync_bn
=
False
,
norm_momentum
=
0.99
,
norm_epsilon
=
0.001
):
if
model_id
not
in
SCALING_MAP
:
raise
ValueError
(
'SpineNet {} is not a valid architecture.'
.
format
(
model_id
))
scaling_params
=
SCALING_MAP
[
model_id
]
self
.
_input_specs
=
input_specs
self
.
_min_level
=
min_level
self
.
_max_level
=
max_level
self
.
_block_specs
=
block_specs
self
.
_endpoints_num_filters
=
scaling_params
[
'endpoints_num_filters'
]
self
.
_resample_alpha
=
scaling_params
[
'resample_alpha'
]
self
.
_block_repeats
=
scaling_params
[
'block_repeats'
]
self
.
_filter_size_scale
=
scaling_params
[
'filter_size_scale'
]
self
.
_kernel_initializer
=
kernel_initializer
self
.
_kernel_regularizer
=
kernel_regularizer
self
.
_bias_regularizer
=
bias_regularizer
self
.
_activation
=
activation
self
.
_use_sync_bn
=
use_sync_bn
self
.
_norm_momentum
=
norm_momentum
self
.
_norm_epsilon
=
norm_epsilon
def
__call__
(
self
,
inputs
,
is_training
=
None
):
with
backend
.
get_graph
().
as_default
():
model
=
SpineNet
(
input_specs
=
self
.
_input_specs
,
min_level
=
self
.
_min_level
,
max_level
=
self
.
_max_level
,
block_specs
=
self
.
_block_specs
,
endpoints_num_filters
=
self
.
_endpoints_num_filters
,
resample_alpha
=
self
.
_resample_alpha
,
block_repeats
=
self
.
_block_repeats
,
filter_size_scale
=
self
.
_filter_size_scale
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
,
activation
=
self
.
_activation
,
use_sync_bn
=
self
.
_use_sync_bn
,
norm_momentum
=
self
.
_norm_momentum
,
norm_epsilon
=
self
.
_norm_epsilon
)
return
model
(
inputs
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment