Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
82b922b4
Commit
82b922b4
authored
Mar 11, 2021
by
Zhichao Lu
Committed by
TF Object Detection Team
Mar 11, 2021
Browse files
Enable hourglass 52 for CenterNet models
PiperOrigin-RevId: 362253972
parent
bee6a471
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
54 additions
and
4 deletions
+54
-4
research/object_detection/builders/model_builder.py
research/object_detection/builders/model_builder.py
+8
-0
research/object_detection/builders/model_builder_tf2_test.py
research/object_detection/builders/model_builder_tf2_test.py
+8
-4
research/object_detection/models/center_net_hourglass_feature_extractor.py
...etection/models/center_net_hourglass_feature_extractor.py
+38
-0
No files found.
research/object_detection/builders/model_builder.py
View file @
82b922b4
...
@@ -153,6 +153,14 @@ if tf_version.is_tf2():
...
@@ -153,6 +153,14 @@ if tf_version.is_tf2():
center_net_resnet_v1_fpn_feature_extractor
.
resnet_v1_50_fpn
,
center_net_resnet_v1_fpn_feature_extractor
.
resnet_v1_50_fpn
,
'resnet_v1_101_fpn'
:
'resnet_v1_101_fpn'
:
center_net_resnet_v1_fpn_feature_extractor
.
resnet_v1_101_fpn
,
center_net_resnet_v1_fpn_feature_extractor
.
resnet_v1_101_fpn
,
'hourglass_10'
:
center_net_hourglass_feature_extractor
.
hourglass_10
,
'hourglass_20'
:
center_net_hourglass_feature_extractor
.
hourglass_20
,
'hourglass_32'
:
center_net_hourglass_feature_extractor
.
hourglass_32
,
'hourglass_52'
:
center_net_hourglass_feature_extractor
.
hourglass_52
,
'hourglass_104'
:
'hourglass_104'
:
center_net_hourglass_feature_extractor
.
hourglass_104
,
center_net_hourglass_feature_extractor
.
hourglass_104
,
'mobilenet_v2'
:
'mobilenet_v2'
:
...
...
research/object_detection/builders/model_builder_tf2_test.py
View file @
82b922b4
...
@@ -24,7 +24,8 @@ from google.protobuf import text_format
...
@@ -24,7 +24,8 @@ from google.protobuf import text_format
from
object_detection.builders
import
model_builder
from
object_detection.builders
import
model_builder
from
object_detection.builders
import
model_builder_test
from
object_detection.builders
import
model_builder_test
from
object_detection.core
import
losses
from
object_detection.core
import
losses
from
object_detection.models
import
center_net_resnet_feature_extractor
from
object_detection.models
import
center_net_hourglass_feature_extractor
from
object_detection.models.keras_models
import
hourglass_network
from
object_detection.protos
import
center_net_pb2
from
object_detection.protos
import
center_net_pb2
from
object_detection.protos
import
model_pb2
from
object_detection.protos
import
model_pb2
from
object_detection.utils
import
tf_version
from
object_detection.utils
import
tf_version
...
@@ -195,7 +196,7 @@ class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest):
...
@@ -195,7 +196,7 @@ class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest):
center_net {
center_net {
num_classes: 10
num_classes: 10
feature_extractor {
feature_extractor {
type: "
resnet_v2_101
"
type: "
hourglass_52
"
channel_stds: [4, 5, 6]
channel_stds: [4, 5, 6]
bgr_ordering: true
bgr_ordering: true
}
}
...
@@ -298,11 +299,14 @@ class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest):
...
@@ -298,11 +299,14 @@ class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest):
# Check feature extractor parameters.
# Check feature extractor parameters.
self
.
assertIsInstance
(
self
.
assertIsInstance
(
model
.
_feature_extractor
,
model
.
_feature_extractor
,
center_net_hourglass_feature_extractor
center_net_resnet_feature_extractor
.
CenterNet
Resnet
FeatureExtractor
)
.
CenterNet
Hourglass
FeatureExtractor
)
self
.
assertAllClose
(
model
.
_feature_extractor
.
_channel_means
,
[
0
,
0
,
0
])
self
.
assertAllClose
(
model
.
_feature_extractor
.
_channel_means
,
[
0
,
0
,
0
])
self
.
assertAllClose
(
model
.
_feature_extractor
.
_channel_stds
,
[
4
,
5
,
6
])
self
.
assertAllClose
(
model
.
_feature_extractor
.
_channel_stds
,
[
4
,
5
,
6
])
self
.
assertTrue
(
model
.
_feature_extractor
.
_bgr_ordering
)
self
.
assertTrue
(
model
.
_feature_extractor
.
_bgr_ordering
)
backbone
=
model
.
_feature_extractor
.
_network
self
.
assertIsInstance
(
backbone
,
hourglass_network
.
HourglassNetwork
)
self
.
assertTrue
(
backbone
.
num_hourglasses
,
1
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
research/object_detection/models/center_net_hourglass_feature_extractor.py
View file @
82b922b4
...
@@ -73,9 +73,47 @@ class CenterNetHourglassFeatureExtractor(
...
@@ -73,9 +73,47 @@ class CenterNetHourglassFeatureExtractor(
ValueError
(
'Sub model type "{}" not supported.'
.
format
(
sub_model_type
))
ValueError
(
'Sub model type "{}" not supported.'
.
format
(
sub_model_type
))
def
hourglass_10
(
channel_means
,
channel_stds
,
bgr_ordering
):
"""The Hourglass-10 backbone for CenterNet."""
network
=
hourglass_network
.
hourglass_10
(
num_channels
=
128
)
return
CenterNetHourglassFeatureExtractor
(
network
,
channel_means
=
channel_means
,
channel_stds
=
channel_stds
,
bgr_ordering
=
bgr_ordering
)
def
hourglass_20
(
channel_means
,
channel_stds
,
bgr_ordering
):
"""The Hourglass-20 backbone for CenterNet."""
network
=
hourglass_network
.
hourglass_20
(
num_channels
=
128
)
return
CenterNetHourglassFeatureExtractor
(
network
,
channel_means
=
channel_means
,
channel_stds
=
channel_stds
,
bgr_ordering
=
bgr_ordering
)
def
hourglass_32
(
channel_means
,
channel_stds
,
bgr_ordering
):
"""The Hourglass-52 backbone for CenterNet."""
network
=
hourglass_network
.
hourglass_32
(
num_channels
=
128
)
return
CenterNetHourglassFeatureExtractor
(
network
,
channel_means
=
channel_means
,
channel_stds
=
channel_stds
,
bgr_ordering
=
bgr_ordering
)
def
hourglass_52
(
channel_means
,
channel_stds
,
bgr_ordering
):
"""The Hourglass-52 backbone for CenterNet."""
network
=
hourglass_network
.
hourglass_52
(
num_channels
=
128
)
return
CenterNetHourglassFeatureExtractor
(
network
,
channel_means
=
channel_means
,
channel_stds
=
channel_stds
,
bgr_ordering
=
bgr_ordering
)
def
hourglass_104
(
channel_means
,
channel_stds
,
bgr_ordering
):
def
hourglass_104
(
channel_means
,
channel_stds
,
bgr_ordering
):
"""The Hourglass-104 backbone for CenterNet."""
"""The Hourglass-104 backbone for CenterNet."""
# TODO(vighneshb): update hourglass_104 signature to match with other
# hourglass networks.
network
=
hourglass_network
.
hourglass_104
()
network
=
hourglass_network
.
hourglass_104
()
return
CenterNetHourglassFeatureExtractor
(
return
CenterNetHourglassFeatureExtractor
(
network
,
channel_means
=
channel_means
,
channel_stds
=
channel_stds
,
network
,
channel_means
=
channel_means
,
channel_stds
=
channel_stds
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment