Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
72d37b69
Commit
72d37b69
authored
Mar 31, 2021
by
Ronny Votel
Committed by
TF Object Detection Team
Mar 31, 2021
Browse files
Updating CenterNet Feature Extractors to plumb through model specific parameters.
PiperOrigin-RevId: 366134194
parent
cee4b75e
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
92 additions
and
47 deletions
+92
-47
research/object_detection/builders/model_builder.py
research/object_detection/builders/model_builder.py
+2
-3
research/object_detection/models/center_net_hourglass_feature_extractor.py
...etection/models/center_net_hourglass_feature_extractor.py
+10
-5
research/object_detection/models/center_net_mobilenet_v2_feature_extractor.py
...ction/models/center_net_mobilenet_v2_feature_extractor.py
+8
-2
research/object_detection/models/center_net_mobilenet_v2_fpn_feature_extractor.py
...n/models/center_net_mobilenet_v2_fpn_feature_extractor.py
+12
-22
research/object_detection/models/center_net_mobilenet_v2_fpn_feature_extractor_tf2_test.py
...center_net_mobilenet_v2_fpn_feature_extractor_tf2_test.py
+41
-9
research/object_detection/models/center_net_resnet_feature_extractor.py
...t_detection/models/center_net_resnet_feature_extractor.py
+4
-2
research/object_detection/models/center_net_resnet_v1_fpn_feature_extractor.py
...tion/models/center_net_resnet_v1_fpn_feature_extractor.py
+8
-4
research/object_detection/protos/center_net.proto
research/object_detection/protos/center_net.proto
+7
-0
No files found.
research/object_detection/builders/model_builder.py
View file @
72d37b69
...
...
@@ -170,9 +170,6 @@ if tf_version.is_tf2():
center_net_mobilenet_v2_feature_extractor
.
mobilenet_v2
,
'mobilenet_v2_fpn'
:
center_net_mobilenet_v2_fpn_feature_extractor
.
mobilenet_v2_fpn
,
'mobilenet_v2_fpn_sep_conv'
:
center_net_mobilenet_v2_fpn_feature_extractor
.
mobilenet_v2_fpn_sep_conv
,
}
FEATURE_EXTRACTOR_MAPS
=
[
...
...
@@ -1130,6 +1127,8 @@ def _build_center_net_feature_extractor(feature_extractor_config, is_training):
'channel_means'
:
list
(
feature_extractor_config
.
channel_means
),
'channel_stds'
:
list
(
feature_extractor_config
.
channel_stds
),
'bgr_ordering'
:
feature_extractor_config
.
bgr_ordering
,
'depth_multiplier'
:
feature_extractor_config
.
depth_multiplier
,
'use_separable_conv'
:
feature_extractor_config
.
use_separable_conv
,
}
...
...
research/object_detection/models/center_net_hourglass_feature_extractor.py
View file @
72d37b69
...
...
@@ -73,8 +73,9 @@ class CenterNetHourglassFeatureExtractor(
ValueError
(
'Sub model type "{}" not supported.'
.
format
(
sub_model_type
))
def
hourglass_10
(
channel_means
,
channel_stds
,
bgr_ordering
):
def
hourglass_10
(
channel_means
,
channel_stds
,
bgr_ordering
,
**
kwargs
):
"""The Hourglass-10 backbone for CenterNet."""
del
kwargs
network
=
hourglass_network
.
hourglass_10
(
num_channels
=
32
)
return
CenterNetHourglassFeatureExtractor
(
...
...
@@ -82,8 +83,9 @@ def hourglass_10(channel_means, channel_stds, bgr_ordering):
bgr_ordering
=
bgr_ordering
)
def
hourglass_20
(
channel_means
,
channel_stds
,
bgr_ordering
):
def
hourglass_20
(
channel_means
,
channel_stds
,
bgr_ordering
,
**
kwargs
):
"""The Hourglass-20 backbone for CenterNet."""
del
kwargs
network
=
hourglass_network
.
hourglass_20
(
num_channels
=
48
)
return
CenterNetHourglassFeatureExtractor
(
...
...
@@ -91,8 +93,9 @@ def hourglass_20(channel_means, channel_stds, bgr_ordering):
bgr_ordering
=
bgr_ordering
)
def
hourglass_32
(
channel_means
,
channel_stds
,
bgr_ordering
):
def
hourglass_32
(
channel_means
,
channel_stds
,
bgr_ordering
,
**
kwargs
):
"""The Hourglass-32 backbone for CenterNet."""
del
kwargs
network
=
hourglass_network
.
hourglass_32
(
num_channels
=
48
)
return
CenterNetHourglassFeatureExtractor
(
...
...
@@ -100,8 +103,9 @@ def hourglass_32(channel_means, channel_stds, bgr_ordering):
bgr_ordering
=
bgr_ordering
)
def
hourglass_52
(
channel_means
,
channel_stds
,
bgr_ordering
):
def
hourglass_52
(
channel_means
,
channel_stds
,
bgr_ordering
,
**
kwargs
):
"""The Hourglass-52 backbone for CenterNet."""
del
kwargs
network
=
hourglass_network
.
hourglass_52
(
num_channels
=
64
)
return
CenterNetHourglassFeatureExtractor
(
...
...
@@ -109,8 +113,9 @@ def hourglass_52(channel_means, channel_stds, bgr_ordering):
bgr_ordering
=
bgr_ordering
)
def
hourglass_104
(
channel_means
,
channel_stds
,
bgr_ordering
):
def
hourglass_104
(
channel_means
,
channel_stds
,
bgr_ordering
,
**
kwargs
):
"""The Hourglass-104 backbone for CenterNet."""
del
kwargs
# TODO(vighneshb): update hourglass_104 signature to match with other
# hourglass networks.
...
...
research/object_detection/models/center_net_mobilenet_v2_feature_extractor.py
View file @
72d37b69
...
...
@@ -110,11 +110,17 @@ class CenterNetMobileNetV2FeatureExtractor(
ValueError
(
'Sub model type "{}" not supported.'
.
format
(
sub_model_type
))
def
mobilenet_v2
(
channel_means
,
channel_stds
,
bgr_ordering
):
def
mobilenet_v2
(
channel_means
,
channel_stds
,
bgr_ordering
,
depth_multiplier
=
1.0
,
**
kwargs
):
"""The MobileNetV2 backbone for CenterNet."""
del
kwargs
# We set 'is_training' to True for now.
network
=
mobilenetv2
.
mobilenet_v2
(
True
,
include_top
=
False
)
network
=
mobilenetv2
.
mobilenet_v2
(
batchnorm_training
=
True
,
alpha
=
depth_multiplier
,
include_top
=
False
,
weights
=
'imagenet'
if
depth_multiplier
==
1.0
else
None
)
return
CenterNetMobileNetV2FeatureExtractor
(
network
,
channel_means
=
channel_means
,
...
...
research/object_detection/models/center_net_mobilenet_v2_fpn_feature_extractor.py
View file @
72d37b69
...
...
@@ -39,7 +39,7 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
channel_means
=
(
0.
,
0.
,
0.
),
channel_stds
=
(
1.
,
1.
,
1.
),
bgr_ordering
=
False
,
fpn
_separable_conv
=
False
):
use
_separable_conv
=
False
):
"""Intializes the feature extractor.
Args:
...
...
@@ -50,7 +50,7 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
channel. Each channel will be divided by its standard deviation value.
bgr_ordering: bool, if set will change the channel ordering to be in the
[blue, red, green] order.
fpn
_separable_conv: If set to True, all convolutional layers in the FPN
use
_separable_conv: If set to True, all convolutional layers in the FPN
network will be replaced by separable convolutions.
"""
...
...
@@ -96,7 +96,7 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
# Merge.
top_down
=
top_down
+
residual
next_num_filters
=
num_filters_list
[
i
+
1
]
if
i
+
1
<=
2
else
24
if
fpn
_separable_conv
:
if
use
_separable_conv
:
conv
=
tf
.
keras
.
layers
.
SeparableConv2D
(
filters
=
next_num_filters
,
kernel_size
=
3
,
strides
=
1
,
padding
=
'same'
)
else
:
...
...
@@ -143,30 +143,20 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
return
1
def
mobilenet_v2_fpn
(
channel_means
,
channel_stds
,
bgr_ordering
):
def
mobilenet_v2_fpn
(
channel_means
,
channel_stds
,
bgr_ordering
,
use_separable_conv
=
False
,
depth_multiplier
=
1.0
,
**
kwargs
):
"""The MobileNetV2+FPN backbone for CenterNet."""
del
kwargs
# Set to batchnorm_training to True for now.
network
=
mobilenetv2
.
mobilenet_v2
(
batchnorm_training
=
True
,
include_top
=
False
)
network
=
mobilenetv2
.
mobilenet_v2
(
batchnorm_training
=
True
,
alpha
=
depth_multiplier
,
include_top
=
False
,
weights
=
'imagenet'
if
depth_multiplier
==
1.0
else
None
)
return
CenterNetMobileNetV2FPNFeatureExtractor
(
network
,
channel_means
=
channel_means
,
channel_stds
=
channel_stds
,
bgr_ordering
=
bgr_ordering
,
fpn_separable_conv
=
False
)
def
mobilenet_v2_fpn_sep_conv
(
channel_means
,
channel_stds
,
bgr_ordering
):
"""Same as mobilenet_v2_fpn except with separable convolution in FPN."""
# Setting batchnorm_training to True, which will use the correct
# BatchNormalization layer strategy based on the current Keras learning phase.
# TODO(yuhuic): expriment with True vs. False to understand it's effect in
# practice.
network
=
mobilenetv2
.
mobilenet_v2
(
batchnorm_training
=
True
,
include_top
=
False
)
return
CenterNetMobileNetV2FPNFeatureExtractor
(
network
,
channel_means
=
channel_means
,
channel_stds
=
channel_stds
,
bgr_ordering
=
bgr_ordering
,
fpn_separable_conv
=
True
)
use_separable_conv
=
use_separable_conv
)
research/object_detection/models/center_net_mobilenet_v2_fpn_feature_extractor_tf2_test.py
View file @
72d37b69
...
...
@@ -18,7 +18,6 @@ import numpy as np
import
tensorflow.compat.v1
as
tf
from
object_detection.models
import
center_net_mobilenet_v2_fpn_feature_extractor
from
object_detection.models.keras_models
import
mobilenet_v2
from
object_detection.utils
import
test_case
from
object_detection.utils
import
tf_version
...
...
@@ -28,10 +27,13 @@ class CenterNetMobileNetV2FPNFeatureExtractorTest(test_case.TestCase):
def
test_center_net_mobilenet_v2_fpn_feature_extractor
(
self
):
net
=
mobilenet_v2
.
mobilenet_v2
(
True
,
include_top
=
False
)
model
=
center_net_mobilenet_v2_fpn_feature_extractor
.
CenterNetMobileNetV2FPNFeatureExtractor
(
net
)
channel_means
=
(
0.
,
0.
,
0.
)
channel_stds
=
(
1.
,
1.
,
1.
)
bgr_ordering
=
False
model
=
(
center_net_mobilenet_v2_fpn_feature_extractor
.
mobilenet_v2_fpn
(
channel_means
,
channel_stds
,
bgr_ordering
,
use_separable_conv
=
False
))
def
graph_fn
():
img
=
np
.
zeros
((
8
,
224
,
224
,
3
),
dtype
=
np
.
float32
)
...
...
@@ -50,10 +52,12 @@ class CenterNetMobileNetV2FPNFeatureExtractorTest(test_case.TestCase):
def
test_center_net_mobilenet_v2_fpn_feature_extractor_sep_conv
(
self
):
net
=
mobilenet_v2
.
mobilenet_v2
(
True
,
include_top
=
False
)
model
=
center_net_mobilenet_v2_fpn_feature_extractor
.
CenterNetMobileNetV2FPNFeatureExtractor
(
net
,
fpn_separable_conv
=
True
)
channel_means
=
(
0.
,
0.
,
0.
)
channel_stds
=
(
1.
,
1.
,
1.
)
bgr_ordering
=
False
model
=
(
center_net_mobilenet_v2_fpn_feature_extractor
.
mobilenet_v2_fpn
(
channel_means
,
channel_stds
,
bgr_ordering
,
use_separable_conv
=
True
))
def
graph_fn
():
img
=
np
.
zeros
((
8
,
224
,
224
,
3
),
dtype
=
np
.
float32
)
...
...
@@ -62,6 +66,10 @@ class CenterNetMobileNetV2FPNFeatureExtractorTest(test_case.TestCase):
outputs
=
self
.
execute
(
graph_fn
,
[])
self
.
assertEqual
(
outputs
.
shape
,
(
8
,
56
,
56
,
24
))
# Pull out the FPN network.
backbone
=
model
.
get_layer
(
'model'
)
first_conv
=
backbone
.
get_layer
(
'Conv1'
)
self
.
assertEqual
(
32
,
first_conv
.
filters
)
# Pull out the FPN network.
output
=
model
.
get_layer
(
'model_1'
)
...
...
@@ -71,6 +79,30 @@ class CenterNetMobileNetV2FPNFeatureExtractorTest(test_case.TestCase):
if
'conv'
in
layer
.
name
and
layer
.
kernel_size
!=
(
1
,
1
):
self
.
assertIsInstance
(
layer
,
tf
.
keras
.
layers
.
SeparableConv2D
)
def
test_center_net_mobilenet_v2_fpn_feature_extractor_depth_multiplier
(
self
):
channel_means
=
(
0.
,
0.
,
0.
)
channel_stds
=
(
1.
,
1.
,
1.
)
bgr_ordering
=
False
model
=
(
center_net_mobilenet_v2_fpn_feature_extractor
.
mobilenet_v2_fpn
(
channel_means
,
channel_stds
,
bgr_ordering
,
use_separable_conv
=
True
,
depth_multiplier
=
2.0
))
def
graph_fn
():
img
=
np
.
zeros
((
8
,
224
,
224
,
3
),
dtype
=
np
.
float32
)
processed_img
=
model
.
preprocess
(
img
)
return
model
(
processed_img
)
outputs
=
self
.
execute
(
graph_fn
,
[])
self
.
assertEqual
(
outputs
.
shape
,
(
8
,
56
,
56
,
24
))
# Pull out the FPN network.
backbone
=
model
.
get_layer
(
'model'
)
first_conv
=
backbone
.
get_layer
(
'Conv1'
)
# Note that the first layer typically has 32 filters, but this model has
# a depth multiplier of 2.
self
.
assertEqual
(
64
,
first_conv
.
filters
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
research/object_detection/models/center_net_resnet_feature_extractor.py
View file @
72d37b69
...
...
@@ -136,8 +136,9 @@ class CenterNetResnetFeatureExtractor(CenterNetFeatureExtractor):
ValueError
(
'Sub model type "{}" not supported.'
.
format
(
sub_model_type
))
def
resnet_v2_101
(
channel_means
,
channel_stds
,
bgr_ordering
):
def
resnet_v2_101
(
channel_means
,
channel_stds
,
bgr_ordering
,
**
kwargs
):
"""The ResNet v2 101 feature extractor."""
del
kwargs
return
CenterNetResnetFeatureExtractor
(
resnet_type
=
'resnet_v2_101'
,
...
...
@@ -147,8 +148,9 @@ def resnet_v2_101(channel_means, channel_stds, bgr_ordering):
)
def
resnet_v2_50
(
channel_means
,
channel_stds
,
bgr_ordering
):
def
resnet_v2_50
(
channel_means
,
channel_stds
,
bgr_ordering
,
**
kwargs
):
"""The ResNet v2 50 feature extractor."""
del
kwargs
return
CenterNetResnetFeatureExtractor
(
resnet_type
=
'resnet_v2_50'
,
...
...
research/object_detection/models/center_net_resnet_v1_fpn_feature_extractor.py
View file @
72d37b69
...
...
@@ -172,8 +172,9 @@ class CenterNetResnetV1FpnFeatureExtractor(CenterNetFeatureExtractor):
ValueError
(
'Sub model type "{}" not supported.'
.
format
(
sub_model_type
))
def
resnet_v1_101_fpn
(
channel_means
,
channel_stds
,
bgr_ordering
):
def
resnet_v1_101_fpn
(
channel_means
,
channel_stds
,
bgr_ordering
,
**
kwargs
):
"""The ResNet v1 101 FPN feature extractor."""
del
kwargs
return
CenterNetResnetV1FpnFeatureExtractor
(
resnet_type
=
'resnet_v1_101'
,
...
...
@@ -183,8 +184,9 @@ def resnet_v1_101_fpn(channel_means, channel_stds, bgr_ordering):
)
def
resnet_v1_50_fpn
(
channel_means
,
channel_stds
,
bgr_ordering
):
def
resnet_v1_50_fpn
(
channel_means
,
channel_stds
,
bgr_ordering
,
**
kwargs
):
"""The ResNet v1 50 FPN feature extractor."""
del
kwargs
return
CenterNetResnetV1FpnFeatureExtractor
(
resnet_type
=
'resnet_v1_50'
,
...
...
@@ -193,8 +195,9 @@ def resnet_v1_50_fpn(channel_means, channel_stds, bgr_ordering):
bgr_ordering
=
bgr_ordering
)
def
resnet_v1_34_fpn
(
channel_means
,
channel_stds
,
bgr_ordering
):
def
resnet_v1_34_fpn
(
channel_means
,
channel_stds
,
bgr_ordering
,
**
kwargs
):
"""The ResNet v1 34 FPN feature extractor."""
del
kwargs
return
CenterNetResnetV1FpnFeatureExtractor
(
resnet_type
=
'resnet_v1_34'
,
...
...
@@ -204,8 +207,9 @@ def resnet_v1_34_fpn(channel_means, channel_stds, bgr_ordering):
)
def
resnet_v1_18_fpn
(
channel_means
,
channel_stds
,
bgr_ordering
):
def
resnet_v1_18_fpn
(
channel_means
,
channel_stds
,
bgr_ordering
,
**
kwargs
):
"""The ResNet v1 18 FPN feature extractor."""
del
kwargs
return
CenterNetResnetV1FpnFeatureExtractor
(
resnet_type
=
'resnet_v1_18'
,
...
...
research/object_detection/protos/center_net.proto
View file @
72d37b69
...
...
@@ -376,5 +376,12 @@ message CenterNetFeatureExtractor {
// network if any.
optional
bool
use_depthwise
=
5
[
default
=
false
];
// Depth multiplier. Only valid for specific models (e.g. MobileNet). See subclasses of `CenterNetFeatureExtractor`.
optional
float
depth_multiplier
=
9
[
default
=
1.0
];
// Whether to use separable convolutions. Only valid for specific
// models. See subclasses of `CenterNetFeatureExtractor`.
optional
bool
use_separable_conv
=
10
[
default
=
false
];
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment