Commit d7a784e6 authored by Yu-hui Chen's avatar Yu-hui Chen Committed by TF Object Detection Team
Browse files

Fixed the bug in the model builder of not setting the default values properly

for the object center head parameters in CenterNet.

PiperOrigin-RevId: 371783912
parent 6faf56a6
...@@ -946,7 +946,7 @@ def object_center_proto_to_params(oc_config): ...@@ -946,7 +946,7 @@ def object_center_proto_to_params(oc_config):
if oc_config.keypoint_weights_for_center: if oc_config.keypoint_weights_for_center:
keypoint_weights_for_center = list(oc_config.keypoint_weights_for_center) keypoint_weights_for_center = list(oc_config.keypoint_weights_for_center)
if oc_config.center_head_params: if oc_config.HasField('center_head_params'):
center_head_num_filters = list(oc_config.center_head_params.num_filters) center_head_num_filters = list(oc_config.center_head_params.num_filters)
center_head_kernel_sizes = list(oc_config.center_head_params.kernel_sizes) center_head_kernel_sizes = list(oc_config.center_head_params.kernel_sizes)
else: else:
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
import os import os
import unittest import unittest
from absl.testing import parameterized
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
from google.protobuf import text_format from google.protobuf import text_format
...@@ -32,7 +33,8 @@ from object_detection.utils import tf_version ...@@ -32,7 +33,8 @@ from object_detection.utils import tf_version
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.') @unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest): class ModelBuilderTF2Test(
model_builder_test.ModelBuilderTest, parameterized.TestCase):
def default_ssd_feature_extractor(self): def default_ssd_feature_extractor(self):
return 'ssd_resnet50_v1_fpn_keras' return 'ssd_resnet50_v1_fpn_keras'
...@@ -79,7 +81,7 @@ class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest): ...@@ -79,7 +81,7 @@ class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest):
f.write(keypoint_spec_text) f.write(keypoint_spec_text)
return keypoint_label_map_path return keypoint_label_map_path
def get_fake_keypoint_proto(self): def get_fake_keypoint_proto(self, customize_head_params=False):
task_proto_txt = """ task_proto_txt = """
task_name: "human_pose" task_name: "human_pose"
task_loss_weight: 0.9 task_loss_weight: 0.9
...@@ -120,18 +122,27 @@ class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest): ...@@ -120,18 +122,27 @@ class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest):
predict_depth: true predict_depth: true
per_keypoint_depth: true per_keypoint_depth: true
keypoint_depth_loss_weight: 0.3 keypoint_depth_loss_weight: 0.3
"""
if customize_head_params:
task_proto_txt += """
heatmap_head_params { heatmap_head_params {
num_filters: 64 num_filters: 64
num_filters: 32 num_filters: 32
kernel_sizes: 5 kernel_sizes: 5
kernel_sizes: 3 kernel_sizes: 3
} }
offset_head_params {
num_filters: 128
num_filters: 64
kernel_sizes: 5
kernel_sizes: 3
}
""" """
config = text_format.Merge(task_proto_txt, config = text_format.Merge(task_proto_txt,
center_net_pb2.CenterNet.KeypointEstimation()) center_net_pb2.CenterNet.KeypointEstimation())
return config return config
def get_fake_object_center_proto(self): def get_fake_object_center_proto(self, customize_head_params=False):
proto_txt = """ proto_txt = """
object_center_loss_weight: 0.5 object_center_loss_weight: 0.5
heatmap_bias_init: 3.14 heatmap_bias_init: 3.14
...@@ -143,6 +154,9 @@ class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest): ...@@ -143,6 +154,9 @@ class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest):
beta: 4.0 beta: 4.0
} }
} }
"""
if customize_head_params:
proto_txt += """
center_head_params { center_head_params {
num_filters: 64 num_filters: 64
num_filters: 32 num_filters: 32
...@@ -222,7 +236,11 @@ class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest): ...@@ -222,7 +236,11 @@ class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest):
return text_format.Merge(proto_txt, return text_format.Merge(proto_txt,
center_net_pb2.CenterNet.DensePoseEstimation()) center_net_pb2.CenterNet.DensePoseEstimation())
def test_create_center_net_model(self): @parameterized.parameters(
{'customize_head_params': True},
{'customize_head_params': False}
)
def test_create_center_net_model(self, customize_head_params):
"""Test building a CenterNet model from proto txt.""" """Test building a CenterNet model from proto txt."""
proto_txt = """ proto_txt = """
center_net { center_net {
...@@ -244,11 +262,13 @@ class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest): ...@@ -244,11 +262,13 @@ class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest):
# Set up the configuration proto. # Set up the configuration proto.
config = text_format.Merge(proto_txt, model_pb2.DetectionModel()) config = text_format.Merge(proto_txt, model_pb2.DetectionModel())
config.center_net.object_center_params.CopyFrom( config.center_net.object_center_params.CopyFrom(
self.get_fake_object_center_proto()) self.get_fake_object_center_proto(
customize_head_params=customize_head_params))
config.center_net.object_detection_task.CopyFrom( config.center_net.object_detection_task.CopyFrom(
self.get_fake_object_detection_proto()) self.get_fake_object_detection_proto())
config.center_net.keypoint_estimation_task.append( config.center_net.keypoint_estimation_task.append(
self.get_fake_keypoint_proto()) self.get_fake_keypoint_proto(
customize_head_params=customize_head_params))
config.center_net.keypoint_label_map_path = ( config.center_net.keypoint_label_map_path = (
self.get_fake_label_map_file_path()) self.get_fake_label_map_file_path())
config.center_net.mask_estimation_task.CopyFrom( config.center_net.mask_estimation_task.CopyFrom(
...@@ -269,8 +289,12 @@ class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest): ...@@ -269,8 +289,12 @@ class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest):
self.assertAlmostEqual( self.assertAlmostEqual(
model._center_params.heatmap_bias_init, 3.14, places=4) model._center_params.heatmap_bias_init, 3.14, places=4)
self.assertEqual(model._center_params.max_box_predictions, 15) self.assertEqual(model._center_params.max_box_predictions, 15)
if customize_head_params:
self.assertEqual(model._center_params.center_head_num_filters, [64, 32]) self.assertEqual(model._center_params.center_head_num_filters, [64, 32])
self.assertEqual(model._center_params.center_head_kernel_sizes, [5, 3]) self.assertEqual(model._center_params.center_head_kernel_sizes, [5, 3])
else:
self.assertEqual(model._center_params.center_head_num_filters, [256])
self.assertEqual(model._center_params.center_head_kernel_sizes, [3])
# Check object detection related parameters. # Check object detection related parameters.
self.assertAlmostEqual(model._od_params.offset_loss_weight, 0.1) self.assertAlmostEqual(model._od_params.offset_loss_weight, 0.1)
...@@ -305,10 +329,16 @@ class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest): ...@@ -305,10 +329,16 @@ class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest):
self.assertEqual(kp_params.predict_depth, True) self.assertEqual(kp_params.predict_depth, True)
self.assertEqual(kp_params.per_keypoint_depth, True) self.assertEqual(kp_params.per_keypoint_depth, True)
self.assertAlmostEqual(kp_params.keypoint_depth_loss_weight, 0.3) self.assertAlmostEqual(kp_params.keypoint_depth_loss_weight, 0.3)
if customize_head_params:
# Set by the config. # Set by the config.
self.assertEqual(kp_params.heatmap_head_num_filters, [64, 32]) self.assertEqual(kp_params.heatmap_head_num_filters, [64, 32])
self.assertEqual(kp_params.heatmap_head_kernel_sizes, [5, 3]) self.assertEqual(kp_params.heatmap_head_kernel_sizes, [5, 3])
self.assertEqual(kp_params.offset_head_num_filters, [128, 64])
self.assertEqual(kp_params.offset_head_kernel_sizes, [5, 3])
else:
# Default values: # Default values:
self.assertEqual(kp_params.heatmap_head_num_filters, [256])
self.assertEqual(kp_params.heatmap_head_kernel_sizes, [3])
self.assertEqual(kp_params.offset_head_num_filters, [256]) self.assertEqual(kp_params.offset_head_num_filters, [256])
self.assertEqual(kp_params.offset_head_kernel_sizes, [3]) self.assertEqual(kp_params.offset_head_kernel_sizes, [3])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment