Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
04c56a03
Commit
04c56a03
authored
Sep 15, 2020
by
A. Unique TensorFlower
Committed by
TF Object Detection Team
Sep 15, 2020
Browse files
Introduce a MobileNetV2+FPN backbone for CenterNet.
PiperOrigin-RevId: 331796097
parent
63f56ffd
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
195 additions
and
4 deletions
+195
-4
research/object_detection/builders/model_builder.py
research/object_detection/builders/model_builder.py
+7
-2
research/object_detection/models/center_net_mobilenet_v2_feature_extractor.py
...ction/models/center_net_mobilenet_v2_feature_extractor.py
+0
-2
research/object_detection/models/center_net_mobilenet_v2_fpn_feature_extractor.py
...n/models/center_net_mobilenet_v2_fpn_feature_extractor.py
+142
-0
research/object_detection/models/center_net_mobilenet_v2_fpn_feature_extractor_tf2_test.py
...center_net_mobilenet_v2_fpn_feature_extractor_tf2_test.py
+46
-0
No files found.
research/object_detection/builders/model_builder.py
View file @
04c56a03
...
...
@@ -50,6 +50,7 @@ from object_detection.utils import tf_version
if
tf_version
.
is_tf2
():
from
object_detection.models
import
center_net_hourglass_feature_extractor
from
object_detection.models
import
center_net_mobilenet_v2_feature_extractor
from
object_detection.models
import
center_net_mobilenet_v2_fpn_feature_extractor
from
object_detection.models
import
center_net_resnet_feature_extractor
from
object_detection.models
import
center_net_resnet_v1_fpn_feature_extractor
from
object_detection.models
import
faster_rcnn_inception_resnet_v2_keras_feature_extractor
as
frcnn_inc_res_keras
...
...
@@ -140,8 +141,10 @@ if tf_version.is_tf2():
}
CENTER_NET_EXTRACTOR_FUNCTION_MAP
=
{
'resnet_v2_50'
:
center_net_resnet_feature_extractor
.
resnet_v2_50
,
'resnet_v2_101'
:
center_net_resnet_feature_extractor
.
resnet_v2_101
,
'resnet_v2_50'
:
center_net_resnet_feature_extractor
.
resnet_v2_50
,
'resnet_v2_101'
:
center_net_resnet_feature_extractor
.
resnet_v2_101
,
'resnet_v1_18_fpn'
:
center_net_resnet_v1_fpn_feature_extractor
.
resnet_v1_18_fpn
,
'resnet_v1_34_fpn'
:
...
...
@@ -154,6 +157,8 @@ if tf_version.is_tf2():
center_net_hourglass_feature_extractor
.
hourglass_104
,
'mobilenet_v2'
:
center_net_mobilenet_v2_feature_extractor
.
mobilenet_v2
,
'mobilenet_v2_fpn'
:
center_net_mobilenet_v2_fpn_feature_extractor
.
mobilenet_v2_fpn
,
}
FEATURE_EXTRACTOR_MAPS
=
[
...
...
research/object_detection/models/center_net_mobilenet_v2_feature_extractor.py
View file @
04c56a03
...
...
@@ -53,8 +53,6 @@ class CenterNetMobileNetV2FeatureExtractor(
output
=
self
.
_network
(
self
.
_network
.
input
)
# TODO(nkhadke): Try out MobileNet+FPN next (skip connections are cheap and
# should help with performance).
# MobileNet by itself transforms a 224x224x3 volume into a 7x7x1280, which
# leads to a stride of 32. We perform upsampling to get it to a target
# stride of 4.
...
...
research/object_detection/models/center_net_mobilenet_v2_fpn_feature_extractor.py
0 → 100644
View file @
04c56a03
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""MobileNet V2[1] + FPN[2] feature extractor for CenterNet[3] meta architecture.
[1]: https://arxiv.org/abs/1801.04381
[2]: https://arxiv.org/abs/1612.03144.
[3]: https://arxiv.org/abs/1904.07850
"""
import
tensorflow.compat.v1
as
tf
from
object_detection.meta_architectures
import
center_net_meta_arch
from
object_detection.models.keras_models
import
mobilenet_v2
as
mobilenetv2
_MOBILENET_V2_FPN_SKIP_LAYERS
=
[
'block_2_add'
,
'block_5_add'
,
'block_9_add'
,
'out_relu'
]
class
CenterNetMobileNetV2FPNFeatureExtractor
(
center_net_meta_arch
.
CenterNetFeatureExtractor
):
"""The MobileNet V2 with FPN skip layers feature extractor for CenterNet."""
def
__init__
(
self
,
mobilenet_v2_net
,
channel_means
=
(
0.
,
0.
,
0.
),
channel_stds
=
(
1.
,
1.
,
1.
),
bgr_ordering
=
False
):
"""Intializes the feature extractor.
Args:
mobilenet_v2_net: The underlying mobilenet_v2 network to use.
channel_means: A tuple of floats, denoting the mean of each channel
which will be subtracted from it.
channel_stds: A tuple of floats, denoting the standard deviation of each
channel. Each channel will be divided by its standard deviation value.
bgr_ordering: bool, if set will change the channel ordering to be in the
[blue, red, green] order.
"""
super
(
CenterNetMobileNetV2FPNFeatureExtractor
,
self
).
__init__
(
channel_means
=
channel_means
,
channel_stds
=
channel_stds
,
bgr_ordering
=
bgr_ordering
)
self
.
_network
=
mobilenet_v2_net
output
=
self
.
_network
(
self
.
_network
.
input
)
# Add pyramid feature network on every layer that has stride 2.
skip_outputs
=
[
self
.
_network
.
get_layer
(
skip_layer_name
).
output
for
skip_layer_name
in
_MOBILENET_V2_FPN_SKIP_LAYERS
]
self
.
_fpn_model
=
tf
.
keras
.
models
.
Model
(
inputs
=
self
.
_network
.
input
,
outputs
=
skip_outputs
)
fpn_outputs
=
self
.
_fpn_model
(
self
.
_network
.
input
)
# Construct the top-down feature maps -- we start with an output of
# 7x7x1280, which we continually upsample, apply a residual on and merge.
# This results in a 56x56x24 output volume.
top_layer
=
fpn_outputs
[
-
1
]
residual_op
=
tf
.
keras
.
layers
.
Conv2D
(
filters
=
64
,
kernel_size
=
1
,
strides
=
1
,
padding
=
'same'
)
top_down
=
residual_op
(
top_layer
)
num_filters_list
=
[
64
,
32
,
24
]
for
i
,
num_filters
in
enumerate
(
num_filters_list
):
level_ind
=
len
(
num_filters_list
)
-
1
-
i
# Upsample.
upsample_op
=
tf
.
keras
.
layers
.
UpSampling2D
(
2
,
interpolation
=
'nearest'
)
top_down
=
upsample_op
(
top_down
)
# Residual (skip-connection) from bottom-up pathway.
residual_op
=
tf
.
keras
.
layers
.
Conv2D
(
filters
=
num_filters
,
kernel_size
=
1
,
strides
=
1
,
padding
=
'same'
)
residual
=
residual_op
(
fpn_outputs
[
level_ind
])
# Merge.
top_down
=
top_down
+
residual
next_num_filters
=
num_filters_list
[
i
+
1
]
if
i
+
1
<=
2
else
24
conv
=
tf
.
keras
.
layers
.
Conv2D
(
filters
=
next_num_filters
,
kernel_size
=
3
,
strides
=
1
,
padding
=
'same'
)
top_down
=
conv
(
top_down
)
top_down
=
tf
.
keras
.
layers
.
BatchNormalization
()(
top_down
)
top_down
=
tf
.
keras
.
layers
.
ReLU
()(
top_down
)
output
=
top_down
self
.
_network
=
tf
.
keras
.
models
.
Model
(
inputs
=
self
.
_network
.
input
,
outputs
=
output
)
def
preprocess
(
self
,
resized_inputs
):
resized_inputs
=
super
(
CenterNetMobileNetV2FPNFeatureExtractor
,
self
).
preprocess
(
resized_inputs
)
return
tf
.
keras
.
applications
.
mobilenet_v2
.
preprocess_input
(
resized_inputs
)
def
load_feature_extractor_weights
(
self
,
path
):
self
.
_network
.
load_weights
(
path
)
def
get_base_model
(
self
):
return
self
.
_network
def
call
(
self
,
inputs
):
return
[
self
.
_network
(
inputs
)]
@
property
def
out_stride
(
self
):
"""The stride in the output image of the network."""
return
4
@
property
def
num_feature_outputs
(
self
):
"""The number of feature outputs returned by the feature extractor."""
return
1
def
get_model
(
self
):
return
self
.
_network
def
mobilenet_v2_fpn
(
channel_means
,
channel_stds
,
bgr_ordering
):
"""The MobileNetV2+FPN backbone for CenterNet."""
# Set to is_training to True for now.
network
=
mobilenetv2
.
mobilenet_v2
(
True
,
include_top
=
False
)
return
CenterNetMobileNetV2FPNFeatureExtractor
(
network
,
channel_means
=
channel_means
,
channel_stds
=
channel_stds
,
bgr_ordering
=
bgr_ordering
)
research/object_detection/models/center_net_mobilenet_v2_fpn_feature_extractor_tf2_test.py
0 → 100644
View file @
04c56a03
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Testing mobilenet_v2+FPN feature extractor for CenterNet."""
import
unittest
import
numpy
as
np
import
tensorflow.compat.v1
as
tf
from
object_detection.models
import
center_net_mobilenet_v2_fpn_feature_extractor
from
object_detection.models.keras_models
import
mobilenet_v2
from
object_detection.utils
import
test_case
from
object_detection.utils
import
tf_version
@
unittest
.
skipIf
(
tf_version
.
is_tf1
(),
'Skipping TF2.X only test.'
)
class
CenterNetMobileNetV2FPNFeatureExtractorTest
(
test_case
.
TestCase
):
def
test_center_net_mobilenet_v2_fpn_feature_extractor
(
self
):
net
=
mobilenet_v2
.
mobilenet_v2
(
True
,
include_top
=
False
)
model
=
center_net_mobilenet_v2_fpn_feature_extractor
.
CenterNetMobileNetV2FPNFeatureExtractor
(
net
)
def
graph_fn
():
img
=
np
.
zeros
((
8
,
224
,
224
,
3
),
dtype
=
np
.
float32
)
processed_img
=
model
.
preprocess
(
img
)
return
model
(
processed_img
)
outputs
=
self
.
execute
(
graph_fn
,
[])
self
.
assertEqual
(
outputs
.
shape
,
(
8
,
56
,
56
,
24
))
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment