"docs/vscode:/vscode.git/clone" did not exist on "bc80dc4ce0ae5a97b8a58faa1d8b8cfbb56e21f5"
Commit 82b922b4 authored by Zhichao Lu's avatar Zhichao Lu Committed by TF Object Detection Team
Browse files

Enable hourglass 52 for CenterNet models

PiperOrigin-RevId: 362253972
parent bee6a471
......@@ -153,6 +153,14 @@ if tf_version.is_tf2():
center_net_resnet_v1_fpn_feature_extractor.resnet_v1_50_fpn,
'resnet_v1_101_fpn':
center_net_resnet_v1_fpn_feature_extractor.resnet_v1_101_fpn,
'hourglass_10':
center_net_hourglass_feature_extractor.hourglass_10,
'hourglass_20':
center_net_hourglass_feature_extractor.hourglass_20,
'hourglass_32':
center_net_hourglass_feature_extractor.hourglass_32,
'hourglass_52':
center_net_hourglass_feature_extractor.hourglass_52,
'hourglass_104':
center_net_hourglass_feature_extractor.hourglass_104,
'mobilenet_v2':
......
......@@ -24,7 +24,8 @@ from google.protobuf import text_format
from object_detection.builders import model_builder
from object_detection.builders import model_builder_test
from object_detection.core import losses
from object_detection.models import center_net_resnet_feature_extractor
from object_detection.models import center_net_hourglass_feature_extractor
from object_detection.models.keras_models import hourglass_network
from object_detection.protos import center_net_pb2
from object_detection.protos import model_pb2
from object_detection.utils import tf_version
......@@ -195,7 +196,7 @@ class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest):
center_net {
num_classes: 10
feature_extractor {
type: "resnet_v2_101"
type: "hourglass_52"
channel_stds: [4, 5, 6]
bgr_ordering: true
}
......@@ -298,11 +299,14 @@ class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest):
# Check feature extractor parameters.
self.assertIsInstance(
model._feature_extractor,
center_net_resnet_feature_extractor.CenterNetResnetFeatureExtractor)
model._feature_extractor, center_net_hourglass_feature_extractor
.CenterNetHourglassFeatureExtractor)
self.assertAllClose(model._feature_extractor._channel_means, [0, 0, 0])
self.assertAllClose(model._feature_extractor._channel_stds, [4, 5, 6])
self.assertTrue(model._feature_extractor._bgr_ordering)
backbone = model._feature_extractor._network
self.assertIsInstance(backbone, hourglass_network.HourglassNetwork)
self.assertTrue(backbone.num_hourglasses, 1)
if __name__ == '__main__':
......
......@@ -73,9 +73,47 @@ class CenterNetHourglassFeatureExtractor(
ValueError('Sub model type "{}" not supported.'.format(sub_model_type))
def hourglass_10(channel_means, channel_stds, bgr_ordering):
"""The Hourglass-10 backbone for CenterNet."""
network = hourglass_network.hourglass_10(num_channels=128)
return CenterNetHourglassFeatureExtractor(
network, channel_means=channel_means, channel_stds=channel_stds,
bgr_ordering=bgr_ordering)
def hourglass_20(channel_means, channel_stds, bgr_ordering):
"""The Hourglass-20 backbone for CenterNet."""
network = hourglass_network.hourglass_20(num_channels=128)
return CenterNetHourglassFeatureExtractor(
network, channel_means=channel_means, channel_stds=channel_stds,
bgr_ordering=bgr_ordering)
def hourglass_32(channel_means, channel_stds, bgr_ordering):
"""The Hourglass-52 backbone for CenterNet."""
network = hourglass_network.hourglass_32(num_channels=128)
return CenterNetHourglassFeatureExtractor(
network, channel_means=channel_means, channel_stds=channel_stds,
bgr_ordering=bgr_ordering)
def hourglass_52(channel_means, channel_stds, bgr_ordering):
"""The Hourglass-52 backbone for CenterNet."""
network = hourglass_network.hourglass_52(num_channels=128)
return CenterNetHourglassFeatureExtractor(
network, channel_means=channel_means, channel_stds=channel_stds,
bgr_ordering=bgr_ordering)
def hourglass_104(channel_means, channel_stds, bgr_ordering):
"""The Hourglass-104 backbone for CenterNet."""
# TODO(vighneshb): update hourglass_104 signature to match with other
# hourglass networks.
network = hourglass_network.hourglass_104()
return CenterNetHourglassFeatureExtractor(
network, channel_means=channel_means, channel_stds=channel_stds,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment