Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
9c314a03
Commit
9c314a03
authored
Jul 17, 2020
by
A. Unique TensorFlower
Committed by
TF Object Detection Team
Jul 17, 2020
Browse files
Internal change
PiperOrigin-RevId: 321825286
parent
ec2d5d8d
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
168 additions
and
1 deletion
+168
-1
research/object_detection/builders/model_builder.py
research/object_detection/builders/model_builder.py
+5
-1
research/object_detection/models/center_net_mobilenet_v2_feature_extractor.py
...ction/models/center_net_mobilenet_v2_feature_extractor.py
+117
-0
research/object_detection/models/center_net_mobilenet_v2_feature_extractor_tf2_test.py
...els/center_net_mobilenet_v2_feature_extractor_tf2_test.py
+46
-0
No files found.
research/object_detection/builders/model_builder.py
View file @
9c314a03
...
@@ -48,6 +48,7 @@ from object_detection.utils import tf_version
...
@@ -48,6 +48,7 @@ from object_detection.utils import tf_version
# pylint: disable=g-import-not-at-top
# pylint: disable=g-import-not-at-top
if
tf_version
.
is_tf2
():
if
tf_version
.
is_tf2
():
from
object_detection.models
import
center_net_hourglass_feature_extractor
from
object_detection.models
import
center_net_hourglass_feature_extractor
from
object_detection.models
import
center_net_mobilenet_v2_feature_extractor
from
object_detection.models
import
center_net_resnet_feature_extractor
from
object_detection.models
import
center_net_resnet_feature_extractor
from
object_detection.models
import
center_net_resnet_v1_fpn_feature_extractor
from
object_detection.models
import
center_net_resnet_v1_fpn_feature_extractor
from
object_detection.models
import
faster_rcnn_inception_resnet_v2_keras_feature_extractor
as
frcnn_inc_res_keras
from
object_detection.models
import
faster_rcnn_inception_resnet_v2_keras_feature_extractor
as
frcnn_inc_res_keras
...
@@ -148,7 +149,10 @@ if tf_version.is_tf2():
...
@@ -148,7 +149,10 @@ if tf_version.is_tf2():
center_net_resnet_v1_fpn_feature_extractor
.
resnet_v1_50_fpn
,
center_net_resnet_v1_fpn_feature_extractor
.
resnet_v1_50_fpn
,
'resnet_v1_101_fpn'
:
'resnet_v1_101_fpn'
:
center_net_resnet_v1_fpn_feature_extractor
.
resnet_v1_101_fpn
,
center_net_resnet_v1_fpn_feature_extractor
.
resnet_v1_101_fpn
,
'hourglass_104'
:
center_net_hourglass_feature_extractor
.
hourglass_104
,
'hourglass_104'
:
center_net_hourglass_feature_extractor
.
hourglass_104
,
'mobilenet_v2'
:
center_net_mobilenet_v2_feature_extractor
.
mobilenet_v2
,
}
}
FEATURE_EXTRACTOR_MAPS
=
[
FEATURE_EXTRACTOR_MAPS
=
[
...
...
research/object_detection/models/center_net_mobilenet_v2_feature_extractor.py
0 → 100644
View file @
9c314a03
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""MobileNet V2[1] feature extractor for CenterNet[2] meta architecture.
[1]: https://arxiv.org/abs/1801.04381
[2]: https://arxiv.org/abs/1904.07850
"""
import
tensorflow.compat.v1
as
tf
from
object_detection.meta_architectures
import
center_net_meta_arch
from
object_detection.models.keras_models
import
mobilenet_v2
as
mobilenetv2
class
CenterNetMobileNetV2FeatureExtractor
(
center_net_meta_arch
.
CenterNetFeatureExtractor
):
"""The MobileNet V2 feature extractor for CenterNet."""
def
__init__
(
self
,
mobilenet_v2_net
,
channel_means
=
(
0.
,
0.
,
0.
),
channel_stds
=
(
1.
,
1.
,
1.
),
bgr_ordering
=
False
):
"""Intializes the feature extractor.
Args:
mobilenet_v2_net: The underlying mobilenet_v2 network to use.
channel_means: A tuple of floats, denoting the mean of each channel
which will be subtracted from it.
channel_stds: A tuple of floats, denoting the standard deviation of each
channel. Each channel will be divided by its standard deviation value.
bgr_ordering: bool, if set will change the channel ordering to be in the
[blue, red, green] order.
"""
super
(
CenterNetMobileNetV2FeatureExtractor
,
self
).
__init__
(
channel_means
=
channel_means
,
channel_stds
=
channel_stds
,
bgr_ordering
=
bgr_ordering
)
self
.
_network
=
mobilenet_v2_net
output
=
self
.
_network
(
self
.
_network
.
input
)
# TODO(nkhadke): Try out MobileNet+FPN next (skip connections are cheap and
# should help with performance).
# MobileNet by itself transforms a 224x224x3 volume into a 7x7x1280, which
# leads to a stride of 32. We perform upsampling to get it to a target
# stride of 4.
for
num_filters
in
[
256
,
128
,
64
]:
# 1. We use a simple convolution instead of a deformable convolution
conv
=
tf
.
keras
.
layers
.
Conv2D
(
filters
=
num_filters
,
kernel_size
=
1
,
strides
=
1
,
padding
=
'same'
)
output
=
conv
(
output
)
output
=
tf
.
keras
.
layers
.
BatchNormalization
()(
output
)
output
=
tf
.
keras
.
layers
.
ReLU
()(
output
)
# 2. We use the default initialization for the convolution layers
# instead of initializing it to do bilinear upsampling.
conv_transpose
=
tf
.
keras
.
layers
.
Conv2DTranspose
(
filters
=
num_filters
,
kernel_size
=
3
,
strides
=
2
,
padding
=
'same'
)
output
=
conv_transpose
(
output
)
output
=
tf
.
keras
.
layers
.
BatchNormalization
()(
output
)
output
=
tf
.
keras
.
layers
.
ReLU
()(
output
)
self
.
_network
=
tf
.
keras
.
models
.
Model
(
inputs
=
self
.
_network
.
input
,
outputs
=
output
)
def
preprocess
(
self
,
resized_inputs
):
resized_inputs
=
super
(
CenterNetMobileNetV2FeatureExtractor
,
self
).
preprocess
(
resized_inputs
)
return
tf
.
keras
.
applications
.
mobilenet_v2
.
preprocess_input
(
resized_inputs
)
def
load_feature_extractor_weights
(
self
,
path
):
self
.
_network
.
load_weights
(
path
)
def
get_base_model
(
self
):
return
self
.
_network
def
call
(
self
,
inputs
):
return
[
self
.
_network
(
inputs
)]
@
property
def
out_stride
(
self
):
"""The stride in the output image of the network."""
return
4
@
property
def
num_feature_outputs
(
self
):
"""The number of feature outputs returned by the feature extractor."""
return
1
def
get_model
(
self
):
return
self
.
_network
def
mobilenet_v2
(
channel_means
,
channel_stds
,
bgr_ordering
):
"""The MobileNetV2 backbone for CenterNet."""
# We set 'is_training' to True for now.
network
=
mobilenetv2
.
mobilenet_v2
(
True
,
include_top
=
False
)
return
CenterNetMobileNetV2FeatureExtractor
(
network
,
channel_means
=
channel_means
,
channel_stds
=
channel_stds
,
bgr_ordering
=
bgr_ordering
)
research/object_detection/models/center_net_mobilenet_v2_feature_extractor_tf2_test.py
0 → 100644
View file @
9c314a03
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Testing mobilenet_v2 feature extractor for CenterNet."""
import
unittest
import
numpy
as
np
import
tensorflow.compat.v1
as
tf
from
object_detection.models
import
center_net_mobilenet_v2_feature_extractor
from
object_detection.models.keras_models
import
mobilenet_v2
from
object_detection.utils
import
test_case
from
object_detection.utils
import
tf_version
@
unittest
.
skipIf
(
tf_version
.
is_tf1
(),
'Skipping TF2.X only test.'
)
class
CenterNetMobileNetV2FeatureExtractorTest
(
test_case
.
TestCase
):
def
test_center_net_mobilenet_v2_feature_extractor
(
self
):
net
=
mobilenet_v2
.
mobilenet_v2
(
True
,
include_top
=
False
)
model
=
center_net_mobilenet_v2_feature_extractor
.
CenterNetMobileNetV2FeatureExtractor
(
net
)
def
graph_fn
():
img
=
np
.
zeros
((
8
,
224
,
224
,
3
),
dtype
=
np
.
float32
)
processed_img
=
model
.
preprocess
(
img
)
return
model
(
processed_img
)
outputs
=
self
.
execute
(
graph_fn
,
[])
self
.
assertEqual
(
outputs
.
shape
,
(
8
,
56
,
56
,
64
))
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment