Commit ec52ef2d authored by Ronny Votel's avatar Ronny Votel Committed by TF Object Detection Team
Browse files

Adding head customization parameters to `ObjectDetection` and `MaskEstimation`...

Adding head customization parameters to `ObjectDetection` and `MaskEstimation` CenterNet proto messages.

PiperOrigin-RevId: 377280733
parent b185135f
......@@ -926,11 +926,27 @@ def object_detection_proto_to_params(od_config):
losses_pb2.WeightedSigmoidClassificationLoss())
loss.localization_loss.CopyFrom(od_config.localization_loss)
_, localization_loss, _, _, _, _, _ = (losses_builder.build(loss))
if od_config.HasField('scale_head_params'):
scale_head_num_filters = list(od_config.scale_head_params.num_filters)
scale_head_kernel_sizes = list(od_config.scale_head_params.kernel_sizes)
else:
scale_head_num_filters = [256]
scale_head_kernel_sizes = [3]
if od_config.HasField('offset_head_params'):
offset_head_num_filters = list(od_config.offset_head_params.num_filters)
offset_head_kernel_sizes = list(od_config.offset_head_params.kernel_sizes)
else:
offset_head_num_filters = [256]
offset_head_kernel_sizes = [3]
return center_net_meta_arch.ObjectDetectionParams(
localization_loss=localization_loss,
scale_loss_weight=od_config.scale_loss_weight,
offset_loss_weight=od_config.offset_loss_weight,
task_loss_weight=od_config.task_loss_weight)
task_loss_weight=od_config.task_loss_weight,
scale_head_num_filters=scale_head_num_filters,
scale_head_kernel_sizes=scale_head_kernel_sizes,
offset_head_num_filters=offset_head_num_filters,
offset_head_kernel_sizes=offset_head_kernel_sizes)
def object_center_proto_to_params(oc_config):
......@@ -973,13 +989,21 @@ def mask_proto_to_params(mask_config):
losses_pb2.WeightedL2LocalizationLoss())
loss.classification_loss.CopyFrom(mask_config.classification_loss)
classification_loss, _, _, _, _, _, _ = (losses_builder.build(loss))
if mask_config.HasField('mask_head_params'):
mask_head_num_filters = list(mask_config.mask_head_params.num_filters)
mask_head_kernel_sizes = list(mask_config.mask_head_params.kernel_sizes)
else:
mask_head_num_filters = [256]
mask_head_kernel_sizes = [3]
return center_net_meta_arch.MaskParams(
classification_loss=classification_loss,
task_loss_weight=mask_config.task_loss_weight,
mask_height=mask_config.mask_height,
mask_width=mask_config.mask_width,
score_threshold=mask_config.score_threshold,
heatmap_bias_init=mask_config.heatmap_bias_init)
heatmap_bias_init=mask_config.heatmap_bias_init,
mask_head_num_filters=mask_head_num_filters,
mask_head_kernel_sizes=mask_head_kernel_sizes)
def densepose_proto_to_params(densepose_config):
......
......@@ -188,7 +188,7 @@ class ModelBuilderTF2Test(
return text_format.Merge(proto_txt,
center_net_pb2.CenterNet.ObjectCenterParams())
def get_fake_object_detection_proto(self):
def get_fake_object_detection_proto(self, customize_head_params=False):
proto_txt = """
task_loss_weight: 0.5
offset_loss_weight: 0.1
......@@ -198,10 +198,19 @@ class ModelBuilderTF2Test(
}
}
"""
if customize_head_params:
proto_txt += """
scale_head_params {
num_filters: 128
num_filters: 64
kernel_sizes: 5
kernel_sizes: 3
}
"""
return text_format.Merge(proto_txt,
center_net_pb2.CenterNet.ObjectDetection())
def get_fake_mask_proto(self):
def get_fake_mask_proto(self, customize_head_params=False):
proto_txt = """
task_loss_weight: 0.7
classification_loss {
......@@ -212,6 +221,15 @@ class ModelBuilderTF2Test(
score_threshold: 0.7
heatmap_bias_init: -2.0
"""
if customize_head_params:
proto_txt += """
mask_head_params {
num_filters: 128
num_filters: 64
kernel_sizes: 5
kernel_sizes: 3
}
"""
return text_format.Merge(proto_txt,
center_net_pb2.CenterNet.MaskEstimation())
......@@ -266,14 +284,16 @@ class ModelBuilderTF2Test(
self.get_fake_object_center_proto(
customize_head_params=customize_head_params))
config.center_net.object_detection_task.CopyFrom(
self.get_fake_object_detection_proto())
self.get_fake_object_detection_proto(
customize_head_params=customize_head_params))
config.center_net.keypoint_estimation_task.append(
self.get_fake_keypoint_proto(
customize_head_params=customize_head_params))
config.center_net.keypoint_label_map_path = (
self.get_fake_label_map_file_path())
config.center_net.mask_estimation_task.CopyFrom(
self.get_fake_mask_proto())
self.get_fake_mask_proto(
customize_head_params=customize_head_params))
config.center_net.densepose_estimation_task.CopyFrom(
self.get_fake_densepose_proto())
......@@ -303,6 +323,14 @@ class ModelBuilderTF2Test(
self.assertAlmostEqual(model._od_params.task_loss_weight, 0.5)
self.assertIsInstance(model._od_params.localization_loss,
losses.L1LocalizationLoss)
self.assertEqual(model._od_params.offset_head_num_filters, [256])
self.assertEqual(model._od_params.offset_head_kernel_sizes, [3])
if customize_head_params:
self.assertEqual(model._od_params.scale_head_num_filters, [128, 64])
self.assertEqual(model._od_params.scale_head_kernel_sizes, [5, 3])
else:
self.assertEqual(model._od_params.scale_head_num_filters, [256])
self.assertEqual(model._od_params.scale_head_kernel_sizes, [3])
# Check keypoint estimation related parameters.
kp_params = model._kp_params_dict['human_pose']
......@@ -352,6 +380,12 @@ class ModelBuilderTF2Test(
self.assertAlmostEqual(model._mask_params.score_threshold, 0.7)
self.assertAlmostEqual(
model._mask_params.heatmap_bias_init, -2.0, places=4)
if customize_head_params:
self.assertEqual(model._mask_params.mask_head_num_filters, [128, 64])
self.assertEqual(model._mask_params.mask_head_kernel_sizes, [5, 3])
else:
self.assertEqual(model._mask_params.mask_head_num_filters, [256])
self.assertEqual(model._mask_params.mask_head_kernel_sizes, [3])
# Check DensePose related parameters.
self.assertEqual(model._densepose_params.class_id, 0)
......
......@@ -1668,7 +1668,9 @@ def predicted_embeddings_at_object_centers(embedding_predictions,
class ObjectDetectionParams(
collections.namedtuple('ObjectDetectionParams', [
'localization_loss', 'scale_loss_weight', 'offset_loss_weight',
'task_loss_weight'
'task_loss_weight', 'scale_head_num_filters',
'scale_head_kernel_sizes', 'offset_head_num_filters',
'offset_head_kernel_sizes'
])):
"""Namedtuple to host object detection related parameters.
......@@ -1684,7 +1686,11 @@ class ObjectDetectionParams(
localization_loss,
scale_loss_weight,
offset_loss_weight,
task_loss_weight=1.0):
task_loss_weight=1.0,
scale_head_num_filters=(256),
scale_head_kernel_sizes=(3),
offset_head_num_filters=(256),
offset_head_kernel_sizes=(3)):
"""Constructor with default values for ObjectDetectionParams.
Args:
......@@ -1697,13 +1703,23 @@ class ObjectDetectionParams(
depending on the input size.
offset_loss_weight: float, The weight for localizing center offsets.
task_loss_weight: float, the weight of the object detection loss.
scale_head_num_filters: filter numbers of the convolutional layers used
by the object detection box scale prediction head.
scale_head_kernel_sizes: kernel size of the convolutional layers used
by the object detection box scale prediction head.
offset_head_num_filters: filter numbers of the convolutional layers used
by the object detection box offset prediction head.
offset_head_kernel_sizes: kernel size of the convolutional layers used
by the object detection box offset prediction head.
Returns:
An initialized ObjectDetectionParams namedtuple.
"""
return super(ObjectDetectionParams,
cls).__new__(cls, localization_loss, scale_loss_weight,
offset_loss_weight, task_loss_weight)
offset_loss_weight, task_loss_weight,
scale_head_num_filters, scale_head_kernel_sizes,
offset_head_num_filters, offset_head_kernel_sizes)
class KeypointEstimationParams(
......@@ -1937,7 +1953,8 @@ class ObjectCenterParams(
class MaskParams(
collections.namedtuple('MaskParams', [
'classification_loss', 'task_loss_weight', 'mask_height', 'mask_width',
'score_threshold', 'heatmap_bias_init'
'score_threshold', 'heatmap_bias_init', 'mask_head_num_filters',
'mask_head_kernel_sizes'
])):
"""Namedtuple to store mask prediction related parameters."""
......@@ -1949,7 +1966,9 @@ class MaskParams(
mask_height=256,
mask_width=256,
score_threshold=0.5,
heatmap_bias_init=-2.19):
heatmap_bias_init=-2.19,
mask_head_num_filters=(256),
mask_head_kernel_sizes=(3)):
"""Constructor with default values for MaskParams.
Args:
......@@ -1963,6 +1982,10 @@ class MaskParams(
heatmap_bias_init: float, the initial value of bias in the convolutional
kernel of the semantic segmentation prediction head. If set to None, the
bias is initialized with zeros.
mask_head_num_filters: filter numbers of the convolutional layers used
by the mask prediction head.
mask_head_kernel_sizes: kernel size of the convolutional layers used
by the mask prediction head.
Returns:
An initialized MaskParams namedtuple.
......@@ -1970,7 +1993,8 @@ class MaskParams(
return super(MaskParams,
cls).__new__(cls, classification_loss,
task_loss_weight, mask_height, mask_width,
score_threshold, heatmap_bias_init)
score_threshold, heatmap_bias_init,
mask_head_num_filters, mask_head_kernel_sizes)
class DensePoseParams(
......@@ -2312,10 +2336,18 @@ class CenterNetMetaArch(model.DetectionModel):
if self._od_params is not None:
prediction_heads[BOX_SCALE] = self._make_prediction_net_list(
num_feature_outputs, NUM_SIZE_CHANNELS, name='box_scale',
num_feature_outputs,
NUM_SIZE_CHANNELS,
kernel_sizes=self._od_params.scale_head_kernel_sizes,
num_filters=self._od_params.scale_head_num_filters,
name='box_scale',
unit_height_conv=unit_height_conv)
prediction_heads[BOX_OFFSET] = self._make_prediction_net_list(
num_feature_outputs, NUM_OFFSET_CHANNELS, name='box_offset',
num_feature_outputs,
NUM_OFFSET_CHANNELS,
kernel_sizes=self._od_params.offset_head_kernel_sizes,
num_filters=self._od_params.offset_head_num_filters,
name='box_offset',
unit_height_conv=unit_height_conv)
if self._kp_params_dict is not None:
......@@ -2370,6 +2402,8 @@ class CenterNetMetaArch(model.DetectionModel):
prediction_heads[SEGMENTATION_HEATMAP] = self._make_prediction_net_list(
num_feature_outputs,
num_classes,
kernel_sizes=self._mask_params.mask_head_kernel_sizes,
num_filters=self._mask_params.mask_head_num_filters,
bias_fill=self._mask_params.heatmap_bias_init,
name='seg_heatmap',
unit_height_conv=unit_height_conv)
......
......@@ -1539,7 +1539,9 @@ def get_fake_mask_params():
classification_loss=losses.WeightedSoftmaxClassificationLoss(),
task_loss_weight=1.0,
mask_height=4,
mask_width=4)
mask_width=4,
mask_head_num_filters=[96],
mask_head_kernel_sizes=[3])
def get_fake_densepose_params():
......
......@@ -65,6 +65,14 @@ message CenterNet {
// Localization loss configuration for object scale and offset losses.
optional LocalizationLoss localization_loss = 8;
// Parameters to determine the architecture of the object scale prediction
// head.
optional PredictionHeadParams scale_head_params = 9;
// Parameters to determine the architecture of the object offset prediction
// head.
optional PredictionHeadParams offset_head_params = 10;
}
optional ObjectDetection object_detection_task = 4;
......@@ -268,6 +276,10 @@ message CenterNet {
// prediction head. -2.19 corresponds to predicting foreground with
// a probability of 0.1.
optional float heatmap_bias_init = 3 [default = -2.19];
// Parameters to determine the architecture of the segmentation mask
// prediction head.
optional PredictionHeadParams mask_head_params = 7;
}
optional MaskEstimation mask_estimation_task = 8;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment