Commit ec52ef2d authored by Ronny Votel, committed by TF Object Detection Team
Browse files

Adding head customization parameters to `ObjectDetection` and `MaskEstimation`...

Adding head customization parameters to `ObjectDetection` and `MaskEstimation` CenterNet proto messages.

PiperOrigin-RevId: 377280733
parent b185135f
...@@ -926,11 +926,27 @@ def object_detection_proto_to_params(od_config): ...@@ -926,11 +926,27 @@ def object_detection_proto_to_params(od_config):
losses_pb2.WeightedSigmoidClassificationLoss()) losses_pb2.WeightedSigmoidClassificationLoss())
loss.localization_loss.CopyFrom(od_config.localization_loss) loss.localization_loss.CopyFrom(od_config.localization_loss)
_, localization_loss, _, _, _, _, _ = (losses_builder.build(loss)) _, localization_loss, _, _, _, _, _ = (losses_builder.build(loss))
if od_config.HasField('scale_head_params'):
scale_head_num_filters = list(od_config.scale_head_params.num_filters)
scale_head_kernel_sizes = list(od_config.scale_head_params.kernel_sizes)
else:
scale_head_num_filters = [256]
scale_head_kernel_sizes = [3]
if od_config.HasField('offset_head_params'):
offset_head_num_filters = list(od_config.offset_head_params.num_filters)
offset_head_kernel_sizes = list(od_config.offset_head_params.kernel_sizes)
else:
offset_head_num_filters = [256]
offset_head_kernel_sizes = [3]
return center_net_meta_arch.ObjectDetectionParams( return center_net_meta_arch.ObjectDetectionParams(
localization_loss=localization_loss, localization_loss=localization_loss,
scale_loss_weight=od_config.scale_loss_weight, scale_loss_weight=od_config.scale_loss_weight,
offset_loss_weight=od_config.offset_loss_weight, offset_loss_weight=od_config.offset_loss_weight,
task_loss_weight=od_config.task_loss_weight) task_loss_weight=od_config.task_loss_weight,
scale_head_num_filters=scale_head_num_filters,
scale_head_kernel_sizes=scale_head_kernel_sizes,
offset_head_num_filters=offset_head_num_filters,
offset_head_kernel_sizes=offset_head_kernel_sizes)
def object_center_proto_to_params(oc_config): def object_center_proto_to_params(oc_config):
...@@ -973,13 +989,21 @@ def mask_proto_to_params(mask_config): ...@@ -973,13 +989,21 @@ def mask_proto_to_params(mask_config):
losses_pb2.WeightedL2LocalizationLoss()) losses_pb2.WeightedL2LocalizationLoss())
loss.classification_loss.CopyFrom(mask_config.classification_loss) loss.classification_loss.CopyFrom(mask_config.classification_loss)
classification_loss, _, _, _, _, _, _ = (losses_builder.build(loss)) classification_loss, _, _, _, _, _, _ = (losses_builder.build(loss))
if mask_config.HasField('mask_head_params'):
mask_head_num_filters = list(mask_config.mask_head_params.num_filters)
mask_head_kernel_sizes = list(mask_config.mask_head_params.kernel_sizes)
else:
mask_head_num_filters = [256]
mask_head_kernel_sizes = [3]
return center_net_meta_arch.MaskParams( return center_net_meta_arch.MaskParams(
classification_loss=classification_loss, classification_loss=classification_loss,
task_loss_weight=mask_config.task_loss_weight, task_loss_weight=mask_config.task_loss_weight,
mask_height=mask_config.mask_height, mask_height=mask_config.mask_height,
mask_width=mask_config.mask_width, mask_width=mask_config.mask_width,
score_threshold=mask_config.score_threshold, score_threshold=mask_config.score_threshold,
heatmap_bias_init=mask_config.heatmap_bias_init) heatmap_bias_init=mask_config.heatmap_bias_init,
mask_head_num_filters=mask_head_num_filters,
mask_head_kernel_sizes=mask_head_kernel_sizes)
def densepose_proto_to_params(densepose_config): def densepose_proto_to_params(densepose_config):
......
...@@ -188,7 +188,7 @@ class ModelBuilderTF2Test( ...@@ -188,7 +188,7 @@ class ModelBuilderTF2Test(
return text_format.Merge(proto_txt, return text_format.Merge(proto_txt,
center_net_pb2.CenterNet.ObjectCenterParams()) center_net_pb2.CenterNet.ObjectCenterParams())
def get_fake_object_detection_proto(self): def get_fake_object_detection_proto(self, customize_head_params=False):
proto_txt = """ proto_txt = """
task_loss_weight: 0.5 task_loss_weight: 0.5
offset_loss_weight: 0.1 offset_loss_weight: 0.1
...@@ -198,10 +198,19 @@ class ModelBuilderTF2Test( ...@@ -198,10 +198,19 @@ class ModelBuilderTF2Test(
} }
} }
""" """
if customize_head_params:
proto_txt += """
scale_head_params {
num_filters: 128
num_filters: 64
kernel_sizes: 5
kernel_sizes: 3
}
"""
return text_format.Merge(proto_txt, return text_format.Merge(proto_txt,
center_net_pb2.CenterNet.ObjectDetection()) center_net_pb2.CenterNet.ObjectDetection())
def get_fake_mask_proto(self): def get_fake_mask_proto(self, customize_head_params=False):
proto_txt = """ proto_txt = """
task_loss_weight: 0.7 task_loss_weight: 0.7
classification_loss { classification_loss {
...@@ -212,6 +221,15 @@ class ModelBuilderTF2Test( ...@@ -212,6 +221,15 @@ class ModelBuilderTF2Test(
score_threshold: 0.7 score_threshold: 0.7
heatmap_bias_init: -2.0 heatmap_bias_init: -2.0
""" """
if customize_head_params:
proto_txt += """
mask_head_params {
num_filters: 128
num_filters: 64
kernel_sizes: 5
kernel_sizes: 3
}
"""
return text_format.Merge(proto_txt, return text_format.Merge(proto_txt,
center_net_pb2.CenterNet.MaskEstimation()) center_net_pb2.CenterNet.MaskEstimation())
...@@ -266,14 +284,16 @@ class ModelBuilderTF2Test( ...@@ -266,14 +284,16 @@ class ModelBuilderTF2Test(
self.get_fake_object_center_proto( self.get_fake_object_center_proto(
customize_head_params=customize_head_params)) customize_head_params=customize_head_params))
config.center_net.object_detection_task.CopyFrom( config.center_net.object_detection_task.CopyFrom(
self.get_fake_object_detection_proto()) self.get_fake_object_detection_proto(
customize_head_params=customize_head_params))
config.center_net.keypoint_estimation_task.append( config.center_net.keypoint_estimation_task.append(
self.get_fake_keypoint_proto( self.get_fake_keypoint_proto(
customize_head_params=customize_head_params)) customize_head_params=customize_head_params))
config.center_net.keypoint_label_map_path = ( config.center_net.keypoint_label_map_path = (
self.get_fake_label_map_file_path()) self.get_fake_label_map_file_path())
config.center_net.mask_estimation_task.CopyFrom( config.center_net.mask_estimation_task.CopyFrom(
self.get_fake_mask_proto()) self.get_fake_mask_proto(
customize_head_params=customize_head_params))
config.center_net.densepose_estimation_task.CopyFrom( config.center_net.densepose_estimation_task.CopyFrom(
self.get_fake_densepose_proto()) self.get_fake_densepose_proto())
...@@ -303,6 +323,14 @@ class ModelBuilderTF2Test( ...@@ -303,6 +323,14 @@ class ModelBuilderTF2Test(
self.assertAlmostEqual(model._od_params.task_loss_weight, 0.5) self.assertAlmostEqual(model._od_params.task_loss_weight, 0.5)
self.assertIsInstance(model._od_params.localization_loss, self.assertIsInstance(model._od_params.localization_loss,
losses.L1LocalizationLoss) losses.L1LocalizationLoss)
self.assertEqual(model._od_params.offset_head_num_filters, [256])
self.assertEqual(model._od_params.offset_head_kernel_sizes, [3])
if customize_head_params:
self.assertEqual(model._od_params.scale_head_num_filters, [128, 64])
self.assertEqual(model._od_params.scale_head_kernel_sizes, [5, 3])
else:
self.assertEqual(model._od_params.scale_head_num_filters, [256])
self.assertEqual(model._od_params.scale_head_kernel_sizes, [3])
# Check keypoint estimation related parameters. # Check keypoint estimation related parameters.
kp_params = model._kp_params_dict['human_pose'] kp_params = model._kp_params_dict['human_pose']
...@@ -352,6 +380,12 @@ class ModelBuilderTF2Test( ...@@ -352,6 +380,12 @@ class ModelBuilderTF2Test(
self.assertAlmostEqual(model._mask_params.score_threshold, 0.7) self.assertAlmostEqual(model._mask_params.score_threshold, 0.7)
self.assertAlmostEqual( self.assertAlmostEqual(
model._mask_params.heatmap_bias_init, -2.0, places=4) model._mask_params.heatmap_bias_init, -2.0, places=4)
if customize_head_params:
self.assertEqual(model._mask_params.mask_head_num_filters, [128, 64])
self.assertEqual(model._mask_params.mask_head_kernel_sizes, [5, 3])
else:
self.assertEqual(model._mask_params.mask_head_num_filters, [256])
self.assertEqual(model._mask_params.mask_head_kernel_sizes, [3])
# Check DensePose related parameters. # Check DensePose related parameters.
self.assertEqual(model._densepose_params.class_id, 0) self.assertEqual(model._densepose_params.class_id, 0)
......
...@@ -1668,7 +1668,9 @@ def predicted_embeddings_at_object_centers(embedding_predictions, ...@@ -1668,7 +1668,9 @@ def predicted_embeddings_at_object_centers(embedding_predictions,
class ObjectDetectionParams( class ObjectDetectionParams(
collections.namedtuple('ObjectDetectionParams', [ collections.namedtuple('ObjectDetectionParams', [
'localization_loss', 'scale_loss_weight', 'offset_loss_weight', 'localization_loss', 'scale_loss_weight', 'offset_loss_weight',
'task_loss_weight' 'task_loss_weight', 'scale_head_num_filters',
'scale_head_kernel_sizes', 'offset_head_num_filters',
'offset_head_kernel_sizes'
])): ])):
"""Namedtuple to host object detection related parameters. """Namedtuple to host object detection related parameters.
...@@ -1684,7 +1686,11 @@ class ObjectDetectionParams( ...@@ -1684,7 +1686,11 @@ class ObjectDetectionParams(
localization_loss, localization_loss,
scale_loss_weight, scale_loss_weight,
offset_loss_weight, offset_loss_weight,
task_loss_weight=1.0): task_loss_weight=1.0,
scale_head_num_filters=(256),
scale_head_kernel_sizes=(3),
offset_head_num_filters=(256),
offset_head_kernel_sizes=(3)):
"""Constructor with default values for ObjectDetectionParams. """Constructor with default values for ObjectDetectionParams.
Args: Args:
...@@ -1697,13 +1703,23 @@ class ObjectDetectionParams( ...@@ -1697,13 +1703,23 @@ class ObjectDetectionParams(
depending on the input size. depending on the input size.
offset_loss_weight: float, The weight for localizing center offsets. offset_loss_weight: float, The weight for localizing center offsets.
task_loss_weight: float, the weight of the object detection loss. task_loss_weight: float, the weight of the object detection loss.
scale_head_num_filters: filter numbers of the convolutional layers used
by the object detection box scale prediction head.
scale_head_kernel_sizes: kernel size of the convolutional layers used
by the object detection box scale prediction head.
offset_head_num_filters: filter numbers of the convolutional layers used
by the object detection box offset prediction head.
offset_head_kernel_sizes: kernel size of the convolutional layers used
by the object detection box offset prediction head.
Returns: Returns:
An initialized ObjectDetectionParams namedtuple. An initialized ObjectDetectionParams namedtuple.
""" """
return super(ObjectDetectionParams, return super(ObjectDetectionParams,
cls).__new__(cls, localization_loss, scale_loss_weight, cls).__new__(cls, localization_loss, scale_loss_weight,
offset_loss_weight, task_loss_weight) offset_loss_weight, task_loss_weight,
scale_head_num_filters, scale_head_kernel_sizes,
offset_head_num_filters, offset_head_kernel_sizes)
class KeypointEstimationParams( class KeypointEstimationParams(
...@@ -1937,7 +1953,8 @@ class ObjectCenterParams( ...@@ -1937,7 +1953,8 @@ class ObjectCenterParams(
class MaskParams( class MaskParams(
collections.namedtuple('MaskParams', [ collections.namedtuple('MaskParams', [
'classification_loss', 'task_loss_weight', 'mask_height', 'mask_width', 'classification_loss', 'task_loss_weight', 'mask_height', 'mask_width',
'score_threshold', 'heatmap_bias_init' 'score_threshold', 'heatmap_bias_init', 'mask_head_num_filters',
'mask_head_kernel_sizes'
])): ])):
"""Namedtuple to store mask prediction related parameters.""" """Namedtuple to store mask prediction related parameters."""
...@@ -1949,7 +1966,9 @@ class MaskParams( ...@@ -1949,7 +1966,9 @@ class MaskParams(
mask_height=256, mask_height=256,
mask_width=256, mask_width=256,
score_threshold=0.5, score_threshold=0.5,
heatmap_bias_init=-2.19): heatmap_bias_init=-2.19,
mask_head_num_filters=(256),
mask_head_kernel_sizes=(3)):
"""Constructor with default values for MaskParams. """Constructor with default values for MaskParams.
Args: Args:
...@@ -1963,6 +1982,10 @@ class MaskParams( ...@@ -1963,6 +1982,10 @@ class MaskParams(
heatmap_bias_init: float, the initial value of bias in the convolutional heatmap_bias_init: float, the initial value of bias in the convolutional
kernel of the semantic segmentation prediction head. If set to None, the kernel of the semantic segmentation prediction head. If set to None, the
bias is initialized with zeros. bias is initialized with zeros.
mask_head_num_filters: filter numbers of the convolutional layers used
by the mask prediction head.
mask_head_kernel_sizes: kernel size of the convolutional layers used
by the mask prediction head.
Returns: Returns:
An initialized MaskParams namedtuple. An initialized MaskParams namedtuple.
...@@ -1970,7 +1993,8 @@ class MaskParams( ...@@ -1970,7 +1993,8 @@ class MaskParams(
return super(MaskParams, return super(MaskParams,
cls).__new__(cls, classification_loss, cls).__new__(cls, classification_loss,
task_loss_weight, mask_height, mask_width, task_loss_weight, mask_height, mask_width,
score_threshold, heatmap_bias_init) score_threshold, heatmap_bias_init,
mask_head_num_filters, mask_head_kernel_sizes)
class DensePoseParams( class DensePoseParams(
...@@ -2312,10 +2336,18 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -2312,10 +2336,18 @@ class CenterNetMetaArch(model.DetectionModel):
if self._od_params is not None: if self._od_params is not None:
prediction_heads[BOX_SCALE] = self._make_prediction_net_list( prediction_heads[BOX_SCALE] = self._make_prediction_net_list(
num_feature_outputs, NUM_SIZE_CHANNELS, name='box_scale', num_feature_outputs,
NUM_SIZE_CHANNELS,
kernel_sizes=self._od_params.scale_head_kernel_sizes,
num_filters=self._od_params.scale_head_num_filters,
name='box_scale',
unit_height_conv=unit_height_conv) unit_height_conv=unit_height_conv)
prediction_heads[BOX_OFFSET] = self._make_prediction_net_list( prediction_heads[BOX_OFFSET] = self._make_prediction_net_list(
num_feature_outputs, NUM_OFFSET_CHANNELS, name='box_offset', num_feature_outputs,
NUM_OFFSET_CHANNELS,
kernel_sizes=self._od_params.offset_head_kernel_sizes,
num_filters=self._od_params.offset_head_num_filters,
name='box_offset',
unit_height_conv=unit_height_conv) unit_height_conv=unit_height_conv)
if self._kp_params_dict is not None: if self._kp_params_dict is not None:
...@@ -2370,6 +2402,8 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -2370,6 +2402,8 @@ class CenterNetMetaArch(model.DetectionModel):
prediction_heads[SEGMENTATION_HEATMAP] = self._make_prediction_net_list( prediction_heads[SEGMENTATION_HEATMAP] = self._make_prediction_net_list(
num_feature_outputs, num_feature_outputs,
num_classes, num_classes,
kernel_sizes=self._mask_params.mask_head_kernel_sizes,
num_filters=self._mask_params.mask_head_num_filters,
bias_fill=self._mask_params.heatmap_bias_init, bias_fill=self._mask_params.heatmap_bias_init,
name='seg_heatmap', name='seg_heatmap',
unit_height_conv=unit_height_conv) unit_height_conv=unit_height_conv)
......
...@@ -1539,7 +1539,9 @@ def get_fake_mask_params(): ...@@ -1539,7 +1539,9 @@ def get_fake_mask_params():
classification_loss=losses.WeightedSoftmaxClassificationLoss(), classification_loss=losses.WeightedSoftmaxClassificationLoss(),
task_loss_weight=1.0, task_loss_weight=1.0,
mask_height=4, mask_height=4,
mask_width=4) mask_width=4,
mask_head_num_filters=[96],
mask_head_kernel_sizes=[3])
def get_fake_densepose_params(): def get_fake_densepose_params():
......
...@@ -65,6 +65,14 @@ message CenterNet { ...@@ -65,6 +65,14 @@ message CenterNet {
// Localization loss configuration for object scale and offset losses. // Localization loss configuration for object scale and offset losses.
optional LocalizationLoss localization_loss = 8; optional LocalizationLoss localization_loss = 8;
// Parameters to determine the architecture of the object scale prediction
// head.
optional PredictionHeadParams scale_head_params = 9;
// Parameters to determine the architecture of the object offset prediction
// head.
optional PredictionHeadParams offset_head_params = 10;
} }
optional ObjectDetection object_detection_task = 4; optional ObjectDetection object_detection_task = 4;
...@@ -268,6 +276,10 @@ message CenterNet { ...@@ -268,6 +276,10 @@ message CenterNet {
// prediction head. -2.19 corresponds to predicting foreground with // prediction head. -2.19 corresponds to predicting foreground with
// a probability of 0.1. // a probability of 0.1.
optional float heatmap_bias_init = 3 [default = -2.19]; optional float heatmap_bias_init = 3 [default = -2.19];
// Parameters to determine the architecture of the segmentation mask
// prediction head.
optional PredictionHeadParams mask_head_params = 7;
} }
optional MaskEstimation mask_estimation_task = 8; optional MaskEstimation mask_estimation_task = 8;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment