Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
f8eb66ea
Commit
f8eb66ea
authored
Mar 22, 2021
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 364395434
parent
7a02b5ce
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
698 additions
and
0 deletions
+698
-0
official/vision/beta/configs/backbones.py
official/vision/beta/configs/backbones.py
+11
-0
official/vision/beta/configs/retinanet.py
official/vision/beta/configs/retinanet.py
+81
-0
official/vision/beta/modeling/backbones/__init__.py
official/vision/beta/modeling/backbones/__init__.py
+1
-0
official/vision/beta/modeling/backbones/spinenet_mobile.py
official/vision/beta/modeling/backbones/spinenet_mobile.py
+508
-0
official/vision/beta/modeling/backbones/spinenet_mobile_test.py
...al/vision/beta/modeling/backbones/spinenet_mobile_test.py
+97
-0
No files found.
official/vision/beta/configs/backbones.py
View file @
f8eb66ea
...
...
@@ -69,6 +69,15 @@ class SpineNet(hyperparams.Config):
stochastic_depth_drop_rate
:
float
=
0.0
@
dataclasses
.
dataclass
class
SpineNetMobile
(
hyperparams
.
Config
):
"""SpineNet config."""
model_id
:
str
=
'49'
stochastic_depth_drop_rate
:
float
=
0.0
se_ratio
:
float
=
0.2
expand_ratio
:
int
=
6
@
dataclasses
.
dataclass
class
RevNet
(
hyperparams
.
Config
):
"""RevNet config."""
...
...
@@ -87,6 +96,7 @@ class Backbone(hyperparams.OneOfConfig):
revnet: revnet backbone config.
efficientnet: efficientnet backbone config.
spinenet: spinenet backbone config.
spinenet_mobile: mobile spinenet backbone config.
mobilenet: mobilenet backbone config.
"""
type
:
Optional
[
str
]
=
None
...
...
@@ -95,4 +105,5 @@ class Backbone(hyperparams.OneOfConfig):
revnet
:
RevNet
=
RevNet
()
efficientnet
:
EfficientNet
=
EfficientNet
()
spinenet
:
SpineNet
=
SpineNet
()
spinenet_mobile
:
SpineNetMobile
=
SpineNetMobile
()
mobilenet
:
MobileNet
=
MobileNet
()
official/vision/beta/configs/retinanet.py
View file @
f8eb66ea
...
...
@@ -300,3 +300,84 @@ def retinanet_spinenet_coco() -> cfg.ExperimentConfig:
])
return
config
@
exp_factory
.
register_config_factory
(
'retinanet_spinenet_mobile_coco'
)
def
retinanet_spinenet_mobile_coco
()
->
cfg
.
ExperimentConfig
:
"""COCO object detection with RetinaNet using Mobile SpineNet backbone."""
train_batch_size
=
256
eval_batch_size
=
8
steps_per_epoch
=
COCO_TRAIN_EXAMPLES
//
train_batch_size
input_size
=
384
config
=
cfg
.
ExperimentConfig
(
runtime
=
cfg
.
RuntimeConfig
(
mixed_precision_dtype
=
'float32'
),
task
=
RetinaNetTask
(
annotation_file
=
os
.
path
.
join
(
COCO_INPUT_PATH_BASE
,
'instances_val2017.json'
),
model
=
RetinaNet
(
backbone
=
backbones
.
Backbone
(
type
=
'spinenet_mobile'
,
spinenet_mobile
=
backbones
.
SpineNetMobile
(
model_id
=
'49'
,
stochastic_depth_drop_rate
=
0.2
)),
decoder
=
decoders
.
Decoder
(
type
=
'identity'
,
identity
=
decoders
.
Identity
()),
anchor
=
Anchor
(
anchor_size
=
3
),
norm_activation
=
common
.
NormActivation
(
use_sync_bn
=
True
,
activation
=
'swish'
),
num_classes
=
91
,
input_size
=
[
input_size
,
input_size
,
3
],
min_level
=
3
,
max_level
=
7
),
losses
=
Losses
(
l2_weight_decay
=
4e-5
),
train_data
=
DataConfig
(
input_path
=
os
.
path
.
join
(
COCO_INPUT_PATH_BASE
,
'train*'
),
is_training
=
True
,
global_batch_size
=
train_batch_size
,
parser
=
Parser
(
aug_rand_hflip
=
True
,
aug_scale_min
=
0.1
,
aug_scale_max
=
2.0
)),
validation_data
=
DataConfig
(
input_path
=
os
.
path
.
join
(
COCO_INPUT_PATH_BASE
,
'val*'
),
is_training
=
False
,
global_batch_size
=
eval_batch_size
)),
trainer
=
cfg
.
TrainerConfig
(
train_steps
=
600
*
steps_per_epoch
,
validation_steps
=
COCO_VAL_EXAMPLES
//
eval_batch_size
,
validation_interval
=
steps_per_epoch
,
steps_per_loop
=
steps_per_epoch
,
summary_interval
=
steps_per_epoch
,
checkpoint_interval
=
steps_per_epoch
,
optimizer_config
=
optimization
.
OptimizationConfig
({
'optimizer'
:
{
'type'
:
'sgd'
,
'sgd'
:
{
'momentum'
:
0.9
}
},
'learning_rate'
:
{
'type'
:
'stepwise'
,
'stepwise'
:
{
'boundaries'
:
[
575
*
steps_per_epoch
,
590
*
steps_per_epoch
],
'values'
:
[
0.32
*
train_batch_size
/
256.0
,
0.032
*
train_batch_size
/
256.0
,
0.0032
*
train_batch_size
/
256.0
],
}
},
'warmup'
:
{
'type'
:
'linear'
,
'linear'
:
{
'warmup_steps'
:
2000
,
'warmup_learning_rate'
:
0.0067
}
}
})),
restrictions
=
[
'task.train_data.is_training != None'
,
'task.validation_data.is_training != None'
])
return
config
official/vision/beta/modeling/backbones/__init__.py
View file @
f8eb66ea
...
...
@@ -22,3 +22,4 @@ from official.vision.beta.modeling.backbones.resnet_3d import ResNet3D
from
official.vision.beta.modeling.backbones.resnet_deeplab
import
DilatedResNet
from
official.vision.beta.modeling.backbones.revnet
import
RevNet
from
official.vision.beta.modeling.backbones.spinenet
import
SpineNet
from
official.vision.beta.modeling.backbones.spinenet_mobile
import
SpineNetMobile
official/vision/beta/modeling/backbones/spinenet_mobile.py
0 → 100644
View file @
f8eb66ea
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains definitions of Mobile SpineNet Networks."""
import
math
# Import libraries
from
absl
import
logging
import
tensorflow
as
tf
from
official.modeling
import
tf_utils
from
official.vision.beta.modeling.backbones
import
factory
from
official.vision.beta.modeling.layers
import
nn_blocks
from
official.vision.beta.modeling.layers
import
nn_layers
from
official.vision.beta.ops
import
spatial_transform_ops
layers
=
tf
.
keras
.
layers
FILTER_SIZE_MAP
=
{
0
:
8
,
1
:
16
,
2
:
24
,
3
:
40
,
4
:
80
,
5
:
112
,
6
:
112
,
7
:
112
,
}
# The fixed SpineNet architecture discovered by NAS.
# Each element represents a specification of a building block:
# (block_level, block_fn, (input_offset0, input_offset1), is_output).
SPINENET_BLOCK_SPECS
=
[
(
2
,
'mbconv'
,
(
0
,
1
),
False
),
(
2
,
'mbconv'
,
(
1
,
2
),
False
),
(
4
,
'mbconv'
,
(
1
,
2
),
False
),
(
3
,
'mbconv'
,
(
3
,
4
),
False
),
(
4
,
'mbconv'
,
(
3
,
5
),
False
),
(
6
,
'mbconv'
,
(
4
,
6
),
False
),
(
4
,
'mbconv'
,
(
4
,
6
),
False
),
(
5
,
'mbconv'
,
(
7
,
8
),
False
),
(
7
,
'mbconv'
,
(
7
,
9
),
False
),
(
5
,
'mbconv'
,
(
9
,
10
),
False
),
(
5
,
'mbconv'
,
(
9
,
11
),
False
),
(
4
,
'mbconv'
,
(
6
,
11
),
True
),
(
3
,
'mbconv'
,
(
5
,
11
),
True
),
(
5
,
'mbconv'
,
(
8
,
13
),
True
),
(
7
,
'mbconv'
,
(
6
,
15
),
True
),
(
6
,
'mbconv'
,
(
13
,
15
),
True
),
]
SCALING_MAP
=
{
'49'
:
{
'endpoints_num_filters'
:
48
,
'filter_size_scale'
:
1.0
,
'block_repeats'
:
1
,
},
'49S'
:
{
'endpoints_num_filters'
:
40
,
'filter_size_scale'
:
0.65
,
'block_repeats'
:
1
,
},
'49XS'
:
{
'endpoints_num_filters'
:
24
,
'filter_size_scale'
:
0.6
,
'block_repeats'
:
1
,
},
}
class
BlockSpec
(
object
):
"""A container class that specifies the block configuration for SpineNet."""
def
__init__
(
self
,
level
,
block_fn
,
input_offsets
,
is_output
):
self
.
level
=
level
self
.
block_fn
=
block_fn
self
.
input_offsets
=
input_offsets
self
.
is_output
=
is_output
def
build_block_specs
(
block_specs
=
None
):
"""Builds the list of BlockSpec objects for SpineNet."""
if
not
block_specs
:
block_specs
=
SPINENET_BLOCK_SPECS
logging
.
info
(
'Building SpineNet block specs: %s'
,
block_specs
)
return
[
BlockSpec
(
*
b
)
for
b
in
block_specs
]
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Vision'
)
class
SpineNetMobile
(
tf
.
keras
.
Model
):
"""Creates a Mobile SpineNet family model.
This implements:
[1] Xianzhi Du, Tsung-Yi Lin, Pengchong Jin, Golnaz Ghiasi, Mingxing Tan,
Yin Cui, Quoc V. Le, Xiaodan Song.
SpineNet: Learning Scale-Permuted Backbone for Recognition and Localization.
(https://arxiv.org/abs/1912.05027).
[2] Xianzhi Du, Tsung-Yi Lin, Pengchong Jin, Yin Cui, Mingxing Tan,
Quoc Le, Xiaodan Song.
Efficient Scale-Permuted Backbone with Learned Resource Distribution.
(https://arxiv.org/abs/2010.11426).
"""
def
__init__
(
self
,
input_specs
=
tf
.
keras
.
layers
.
InputSpec
(
shape
=
[
None
,
512
,
512
,
3
]),
min_level
=
3
,
max_level
=
7
,
block_specs
=
build_block_specs
(),
endpoints_num_filters
=
48
,
se_ratio
=
0.2
,
block_repeats
=
1
,
filter_size_scale
=
1.0
,
expand_ratio
=
6
,
init_stochastic_depth_rate
=
0.0
,
kernel_initializer
=
'VarianceScaling'
,
kernel_regularizer
=
None
,
bias_regularizer
=
None
,
activation
=
'relu'
,
use_sync_bn
=
False
,
norm_momentum
=
0.99
,
norm_epsilon
=
0.001
,
**
kwargs
):
"""Initializes a Mobile SpineNet model.
Args:
input_specs: A `tf.keras.layers.InputSpec` of the input tensor.
min_level: An `int` of min level for output mutiscale features.
max_level: An `int` of max level for output mutiscale features.
block_specs: The block specifications for the SpineNet model discovered by
NAS.
endpoints_num_filters: An `int` of feature dimension for the output
endpoints.
se_ratio: A `float` of Squeeze-and-Excitation ratio.
block_repeats: An `int` of number of blocks contained in the layer.
filter_size_scale: A `float` of multiplier for the filters (number of
channels) for all convolution ops. The value must be greater than zero.
Typical usage will be to set this value in (0, 1) to reduce the number
of parameters or computation cost of the model.
expand_ratio: An `integer` of expansion ratios for inverted bottleneck
blocks.
init_stochastic_depth_rate: A `float` of initial stochastic depth rate.
kernel_initializer: A str for kernel initializer of convolutional layers.
kernel_regularizer: A `tf.keras.regularizers.Regularizer` object for
Conv2D. Default to None.
bias_regularizer: A `tf.keras.regularizers.Regularizer` object for Conv2D.
Default to None.
activation: A `str` name of the activation function.
use_sync_bn: If True, use synchronized batch normalization.
norm_momentum: A `float` of normalization momentum for the moving average.
norm_epsilon: A small `float` added to variance to avoid dividing by zero.
**kwargs: Additional keyword arguments to be passed.
"""
self
.
_input_specs
=
input_specs
self
.
_min_level
=
min_level
self
.
_max_level
=
max_level
self
.
_block_specs
=
block_specs
self
.
_endpoints_num_filters
=
endpoints_num_filters
self
.
_se_ratio
=
se_ratio
self
.
_block_repeats
=
block_repeats
self
.
_filter_size_scale
=
filter_size_scale
self
.
_expand_ratio
=
expand_ratio
self
.
_init_stochastic_depth_rate
=
init_stochastic_depth_rate
self
.
_kernel_initializer
=
kernel_initializer
self
.
_kernel_regularizer
=
kernel_regularizer
self
.
_bias_regularizer
=
bias_regularizer
self
.
_activation
=
activation
self
.
_use_sync_bn
=
use_sync_bn
self
.
_norm_momentum
=
norm_momentum
self
.
_norm_epsilon
=
norm_epsilon
if
activation
==
'relu'
:
self
.
_activation_fn
=
tf
.
nn
.
relu
elif
activation
==
'swish'
:
self
.
_activation_fn
=
tf
.
nn
.
swish
else
:
raise
ValueError
(
'Activation {} not implemented.'
.
format
(
activation
))
self
.
_num_init_blocks
=
2
if
use_sync_bn
:
self
.
_norm
=
layers
.
experimental
.
SyncBatchNormalization
else
:
self
.
_norm
=
layers
.
BatchNormalization
if
tf
.
keras
.
backend
.
image_data_format
()
==
'channels_last'
:
self
.
_bn_axis
=
-
1
else
:
self
.
_bn_axis
=
1
# Build SpineNet.
inputs
=
tf
.
keras
.
Input
(
shape
=
input_specs
.
shape
[
1
:])
net
=
self
.
_build_stem
(
inputs
=
inputs
)
net
=
self
.
_build_scale_permuted_network
(
net
=
net
,
input_width
=
input_specs
.
shape
[
2
])
endpoints
=
self
.
_build_endpoints
(
net
=
net
)
self
.
_output_specs
=
{
l
:
endpoints
[
l
].
get_shape
()
for
l
in
endpoints
}
super
().
__init__
(
inputs
=
inputs
,
outputs
=
endpoints
)
def
_block_group
(
self
,
inputs
,
in_filters
,
out_filters
,
strides
,
expand_ratio
=
6
,
block_repeats
=
1
,
se_ratio
=
0.2
,
stochastic_depth_drop_rate
=
None
,
name
=
'block_group'
):
"""Creates one group of blocks for the SpineNet model."""
x
=
nn_blocks
.
InvertedBottleneckBlock
(
in_filters
=
in_filters
,
out_filters
=
out_filters
,
strides
=
strides
,
se_ratio
=
se_ratio
,
expand_ratio
=
expand_ratio
,
stochastic_depth_drop_rate
=
stochastic_depth_drop_rate
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
,
activation
=
self
.
_activation
,
use_sync_bn
=
self
.
_use_sync_bn
,
norm_momentum
=
self
.
_norm_momentum
,
norm_epsilon
=
self
.
_norm_epsilon
)(
inputs
)
for
_
in
range
(
1
,
block_repeats
):
x
=
nn_blocks
.
InvertedBottleneckBlock
(
in_filters
=
in_filters
,
out_filters
=
out_filters
,
strides
=
1
,
se_ratio
=
se_ratio
,
expand_ratio
=
expand_ratio
,
stochastic_depth_drop_rate
=
stochastic_depth_drop_rate
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
,
activation
=
self
.
_activation
,
use_sync_bn
=
self
.
_use_sync_bn
,
norm_momentum
=
self
.
_norm_momentum
,
norm_epsilon
=
self
.
_norm_epsilon
)(
inputs
)
return
tf
.
identity
(
x
,
name
=
name
)
def
_build_stem
(
self
,
inputs
):
"""Builds SpineNet stem."""
x
=
layers
.
Conv2D
(
filters
=
int
(
FILTER_SIZE_MAP
[
0
]
*
self
.
_filter_size_scale
),
kernel_size
=
3
,
strides
=
2
,
use_bias
=
False
,
padding
=
'same'
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
)(
inputs
)
x
=
self
.
_norm
(
axis
=
self
.
_bn_axis
,
momentum
=
self
.
_norm_momentum
,
epsilon
=
self
.
_norm_epsilon
)(
x
)
x
=
tf_utils
.
get_activation
(
self
.
_activation_fn
)(
x
)
net
=
[]
stem_strides
=
[
1
,
2
]
# Build the initial level 2 blocks.
for
i
in
range
(
self
.
_num_init_blocks
):
x
=
self
.
_block_group
(
inputs
=
x
,
in_filters
=
int
(
FILTER_SIZE_MAP
[
i
]
*
self
.
_filter_size_scale
),
out_filters
=
int
(
FILTER_SIZE_MAP
[
i
+
1
]
*
self
.
_filter_size_scale
),
expand_ratio
=
self
.
_expand_ratio
,
strides
=
stem_strides
[
i
],
se_ratio
=
self
.
_se_ratio
,
block_repeats
=
self
.
_block_repeats
,
name
=
'stem_block_{}'
.
format
(
i
+
1
))
net
.
append
(
x
)
return
net
def
_build_scale_permuted_network
(
self
,
net
,
input_width
,
weighted_fusion
=
False
):
"""Builds scale-permuted network."""
net_sizes
=
[
int
(
math
.
ceil
(
input_width
/
2
)),
int
(
math
.
ceil
(
input_width
/
2
**
2
))
]
num_outgoing_connections
=
[
0
]
*
len
(
net
)
endpoints
=
{}
for
i
,
block_spec
in
enumerate
(
self
.
_block_specs
):
# Find out specs for the target block.
target_width
=
int
(
math
.
ceil
(
input_width
/
2
**
block_spec
.
level
))
target_num_filters
=
int
(
FILTER_SIZE_MAP
[
block_spec
.
level
]
*
self
.
_filter_size_scale
)
# Resample then merge input0 and input1.
parents
=
[]
input0
=
block_spec
.
input_offsets
[
0
]
input1
=
block_spec
.
input_offsets
[
1
]
x0
=
self
.
_resample_with_sepconv
(
inputs
=
net
[
input0
],
input_width
=
net_sizes
[
input0
],
target_width
=
target_width
,
target_num_filters
=
target_num_filters
)
parents
.
append
(
x0
)
num_outgoing_connections
[
input0
]
+=
1
x1
=
self
.
_resample_with_sepconv
(
inputs
=
net
[
input1
],
input_width
=
net_sizes
[
input1
],
target_width
=
target_width
,
target_num_filters
=
target_num_filters
)
parents
.
append
(
x1
)
num_outgoing_connections
[
input1
]
+=
1
# Merge 0 outdegree blocks to the output block.
if
block_spec
.
is_output
:
for
j
,
(
j_feat
,
j_connections
)
in
enumerate
(
zip
(
net
,
num_outgoing_connections
)):
if
j_connections
==
0
and
(
j_feat
.
shape
[
2
]
==
target_width
and
j_feat
.
shape
[
3
]
==
x0
.
shape
[
3
]):
parents
.
append
(
j_feat
)
num_outgoing_connections
[
j
]
+=
1
# pylint: disable=g-direct-tensorflow-import
if
weighted_fusion
:
dtype
=
parents
[
0
].
dtype
parent_weights
=
[
tf
.
nn
.
relu
(
tf
.
cast
(
tf
.
Variable
(
1.0
,
name
=
'block{}_fusion{}'
.
format
(
i
,
j
)),
dtype
=
dtype
))
for
j
in
range
(
len
(
parents
))]
weights_sum
=
tf
.
add_n
(
parent_weights
)
parents
=
[
parents
[
i
]
*
parent_weights
[
i
]
/
(
weights_sum
+
0.0001
)
for
i
in
range
(
len
(
parents
))
]
# Fuse all parent nodes then build a new block.
x
=
tf_utils
.
get_activation
(
self
.
_activation_fn
)(
tf
.
add_n
(
parents
))
x
=
self
.
_block_group
(
inputs
=
x
,
in_filters
=
target_num_filters
,
out_filters
=
target_num_filters
,
strides
=
1
,
se_ratio
=
self
.
_se_ratio
,
expand_ratio
=
self
.
_expand_ratio
,
block_repeats
=
self
.
_block_repeats
,
stochastic_depth_drop_rate
=
nn_layers
.
get_stochastic_depth_rate
(
self
.
_init_stochastic_depth_rate
,
i
+
1
,
len
(
self
.
_block_specs
)),
name
=
'scale_permuted_block_{}'
.
format
(
i
+
1
))
net
.
append
(
x
)
net_sizes
.
append
(
target_width
)
num_outgoing_connections
.
append
(
0
)
# Save output feats.
if
block_spec
.
is_output
:
if
block_spec
.
level
in
endpoints
:
raise
ValueError
(
'Duplicate feats found for output level {}.'
.
format
(
block_spec
.
level
))
if
(
block_spec
.
level
<
self
.
_min_level
or
block_spec
.
level
>
self
.
_max_level
):
raise
ValueError
(
'Output level is out of range [{}, {}]'
.
format
(
self
.
_min_level
,
self
.
_max_level
))
endpoints
[
str
(
block_spec
.
level
)]
=
x
return
endpoints
def
_build_endpoints
(
self
,
net
):
"""Matches filter size for endpoints before sharing conv layers."""
endpoints
=
{}
for
level
in
range
(
self
.
_min_level
,
self
.
_max_level
+
1
):
x
=
layers
.
Conv2D
(
filters
=
self
.
_endpoints_num_filters
,
kernel_size
=
1
,
strides
=
1
,
use_bias
=
False
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
)(
net
[
str
(
level
)])
x
=
self
.
_norm
(
axis
=
self
.
_bn_axis
,
momentum
=
self
.
_norm_momentum
,
epsilon
=
self
.
_norm_epsilon
)(
x
)
x
=
tf_utils
.
get_activation
(
self
.
_activation_fn
)(
x
)
endpoints
[
str
(
level
)]
=
x
return
endpoints
def
_resample_with_sepconv
(
self
,
inputs
,
input_width
,
target_width
,
target_num_filters
):
"""Matches resolution and feature dimension."""
x
=
inputs
# Spatial resampling.
if
input_width
>
target_width
:
while
input_width
>
target_width
:
x
=
layers
.
DepthwiseConv2D
(
kernel_size
=
3
,
strides
=
2
,
padding
=
'SAME'
,
use_bias
=
False
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
)(
x
)
x
=
self
.
_norm
(
axis
=
self
.
_bn_axis
,
momentum
=
self
.
_norm_momentum
,
epsilon
=
self
.
_norm_epsilon
)(
x
)
x
=
tf_utils
.
get_activation
(
self
.
_activation_fn
)(
x
)
input_width
/=
2
elif
input_width
<
target_width
:
scale
=
target_width
//
input_width
x
=
spatial_transform_ops
.
nearest_upsampling
(
x
,
scale
=
scale
)
# Last 1x1 conv to match filter size.
x
=
layers
.
Conv2D
(
filters
=
target_num_filters
,
kernel_size
=
1
,
strides
=
1
,
use_bias
=
False
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
)(
x
)
x
=
self
.
_norm
(
axis
=
self
.
_bn_axis
,
momentum
=
self
.
_norm_momentum
,
epsilon
=
self
.
_norm_epsilon
)(
x
)
return
x
def
get_config
(
self
):
config_dict
=
{
'min_level'
:
self
.
_min_level
,
'max_level'
:
self
.
_max_level
,
'endpoints_num_filters'
:
self
.
_endpoints_num_filters
,
'se_ratio'
:
self
.
_se_ratio
,
'expand_ratio'
:
self
.
_expand_ratio
,
'block_repeats'
:
self
.
_block_repeats
,
'filter_size_scale'
:
self
.
_filter_size_scale
,
'init_stochastic_depth_rate'
:
self
.
_init_stochastic_depth_rate
,
'kernel_initializer'
:
self
.
_kernel_initializer
,
'kernel_regularizer'
:
self
.
_kernel_regularizer
,
'bias_regularizer'
:
self
.
_bias_regularizer
,
'activation'
:
self
.
_activation
,
'use_sync_bn'
:
self
.
_use_sync_bn
,
'norm_momentum'
:
self
.
_norm_momentum
,
'norm_epsilon'
:
self
.
_norm_epsilon
}
return
config_dict
@
classmethod
def
from_config
(
cls
,
config
,
custom_objects
=
None
):
return
cls
(
**
config
)
@
property
def
output_specs
(
self
):
"""A dict of {level: TensorShape} pairs for the model output."""
return
self
.
_output_specs
@
factory
.
register_backbone_builder
(
'spinenet_mobile'
)
def
build_spinenet_mobile
(
input_specs
:
tf
.
keras
.
layers
.
InputSpec
,
model_config
,
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
)
->
tf
.
keras
.
Model
:
"""Builds Mobile SpineNet backbone from a config."""
backbone_type
=
model_config
.
backbone
.
type
backbone_cfg
=
model_config
.
backbone
.
get
()
norm_activation_config
=
model_config
.
norm_activation
assert
backbone_type
==
'spinenet_mobile'
,
(
f
'Inconsistent backbone type '
f
'
{
backbone_type
}
'
)
model_id
=
backbone_cfg
.
model_id
if
model_id
not
in
SCALING_MAP
:
raise
ValueError
(
'Mobile SpineNet-{} is not a valid architecture.'
.
format
(
model_id
))
scaling_params
=
SCALING_MAP
[
model_id
]
return
SpineNetMobile
(
input_specs
=
input_specs
,
min_level
=
model_config
.
min_level
,
max_level
=
model_config
.
max_level
,
endpoints_num_filters
=
scaling_params
[
'endpoints_num_filters'
],
block_repeats
=
scaling_params
[
'block_repeats'
],
filter_size_scale
=
scaling_params
[
'filter_size_scale'
],
se_ratio
=
backbone_cfg
.
se_ratio
,
expand_ratio
=
backbone_cfg
.
expand_ratio
,
init_stochastic_depth_rate
=
backbone_cfg
.
stochastic_depth_drop_rate
,
kernel_regularizer
=
l2_regularizer
,
activation
=
norm_activation_config
.
activation
,
use_sync_bn
=
norm_activation_config
.
use_sync_bn
,
norm_momentum
=
norm_activation_config
.
norm_momentum
,
norm_epsilon
=
norm_activation_config
.
norm_epsilon
)
official/vision/beta/modeling/backbones/spinenet_mobile_test.py
0 → 100644
View file @
f8eb66ea
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for SpineNet."""
# Import libraries
from
absl.testing
import
parameterized
import
tensorflow
as
tf
from
official.vision.beta.modeling.backbones
import
spinenet_mobile
class
SpineNetMobileTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
@
parameterized
.
parameters
(
(
128
,
0.6
,
1
,
0.0
,
24
),
(
128
,
0.65
,
1
,
0.2
,
40
),
(
256
,
1.0
,
1
,
0.2
,
48
),
)
def
test_network_creation
(
self
,
input_size
,
filter_size_scale
,
block_repeats
,
se_ratio
,
endpoints_num_filters
):
"""Test creation of SpineNet models."""
min_level
=
3
max_level
=
7
tf
.
keras
.
backend
.
set_image_data_format
(
'channels_last'
)
input_specs
=
tf
.
keras
.
layers
.
InputSpec
(
shape
=
[
None
,
input_size
,
input_size
,
3
])
model
=
spinenet_mobile
.
SpineNetMobile
(
input_specs
=
input_specs
,
min_level
=
min_level
,
max_level
=
max_level
,
endpoints_num_filters
=
endpoints_num_filters
,
resample_alpha
=
se_ratio
,
block_repeats
=
block_repeats
,
filter_size_scale
=
filter_size_scale
,
init_stochastic_depth_rate
=
0.2
,
)
inputs
=
tf
.
keras
.
Input
(
shape
=
(
input_size
,
input_size
,
3
),
batch_size
=
1
)
endpoints
=
model
(
inputs
)
for
l
in
range
(
min_level
,
max_level
+
1
):
self
.
assertIn
(
str
(
l
),
endpoints
.
keys
())
self
.
assertAllEqual
(
[
1
,
input_size
/
2
**
l
,
input_size
/
2
**
l
,
endpoints_num_filters
],
endpoints
[
str
(
l
)].
shape
.
as_list
())
def
test_serialize_deserialize
(
self
):
# Create a network object that sets all of its config options.
kwargs
=
dict
(
min_level
=
3
,
max_level
=
7
,
endpoints_num_filters
=
256
,
se_ratio
=
0.2
,
expand_ratio
=
6
,
block_repeats
=
1
,
filter_size_scale
=
1.0
,
init_stochastic_depth_rate
=
0.2
,
use_sync_bn
=
False
,
activation
=
'relu'
,
norm_momentum
=
0.99
,
norm_epsilon
=
0.001
,
kernel_initializer
=
'VarianceScaling'
,
kernel_regularizer
=
None
,
bias_regularizer
=
None
,
)
network
=
spinenet_mobile
.
SpineNetMobile
(
**
kwargs
)
expected_config
=
dict
(
kwargs
)
self
.
assertEqual
(
network
.
get_config
(),
expected_config
)
# Create another network object from the first object's config.
new_network
=
spinenet_mobile
.
SpineNetMobile
.
from_config
(
network
.
get_config
())
# Validate that the config can be forced to JSON.
_
=
new_network
.
to_json
()
# If the serialization was successful, the new config should match the old.
self
.
assertAllEqual
(
network
.
get_config
(),
new_network
.
get_config
())
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment