Commit 6e5c5a1b authored by Yu-hui Chen, committed by TF Object Detection Team

Updated the logic so that the CenterNet prediction head architectures are configurable.

PiperOrigin-RevId: 364731599
parent 5c3e08b7
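For readers skimming the change: the sketch below is not part of the commit, but it shows how the new head-architecture fields introduced here can be set in a CenterNet keypoint config and read back with the protobuf text format, mirroring the unit tests further down. The parsed lists are what keypoint_proto_to_params consumes in the first hunk.

# Sketch only: assumes the object_detection protos are built and importable.
from google.protobuf import text_format
from object_detection.protos import center_net_pb2

kp_config = text_format.Merge(
    """
    heatmap_head_params {
      num_filters: 64
      num_filters: 32
      kernel_sizes: 5
      kernel_sizes: 3
    }
    """,
    center_net_pb2.CenterNet.KeypointEstimation())

# Two intermediate conv layers: 64 filters with 5x5 kernels, then 32 with 3x3.
print(kp_config.HasField('heatmap_head_params'))        # True
print(list(kp_config.heatmap_head_params.num_filters))  # [64, 32]
print(list(kp_config.heatmap_head_params.kernel_sizes)) # [5, 3]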
@@ -860,6 +860,25 @@ def keypoint_proto_to_params(kp_config, keypoint_map_dict):
for label, value in kp_config.keypoint_label_to_std.items():
keypoint_std_dev_dict[label] = value
keypoint_std_dev = [keypoint_std_dev_dict[label] for label in keypoint_labels]
if kp_config.HasField('heatmap_head_params'):
heatmap_head_num_filters = list(kp_config.heatmap_head_params.num_filters)
heatmap_head_kernel_sizes = list(kp_config.heatmap_head_params.kernel_sizes)
else:
heatmap_head_num_filters = [256]
heatmap_head_kernel_sizes = [3]
if kp_config.HasField('offset_head_params'):
offset_head_num_filters = list(kp_config.offset_head_params.num_filters)
offset_head_kernel_sizes = list(kp_config.offset_head_params.kernel_sizes)
else:
offset_head_num_filters = [256]
offset_head_kernel_sizes = [3]
if kp_config.HasField('regress_head_params'):
regress_head_num_filters = list(kp_config.regress_head_params.num_filters)
regress_head_kernel_sizes = list(
kp_config.regress_head_params.kernel_sizes)
else:
regress_head_num_filters = [256]
regress_head_kernel_sizes = [3]
return center_net_meta_arch.KeypointEstimationParams(
task_name=kp_config.task_name,
class_id=label_map_item.id - CLASS_ID_OFFSET,
@@ -888,7 +907,13 @@ def keypoint_proto_to_params(kp_config, keypoint_map_dict):
keypoint_depth_loss_weight=kp_config.keypoint_depth_loss_weight,
score_distance_offset=kp_config.score_distance_offset,
clip_out_of_frame_keypoints=kp_config.clip_out_of_frame_keypoints,
rescore_instances=kp_config.rescore_instances)
rescore_instances=kp_config.rescore_instances,
heatmap_head_num_filters=heatmap_head_num_filters,
heatmap_head_kernel_sizes=heatmap_head_kernel_sizes,
offset_head_num_filters=offset_head_num_filters,
offset_head_kernel_sizes=offset_head_kernel_sizes,
regress_head_num_filters=regress_head_num_filters,
regress_head_kernel_sizes=regress_head_kernel_sizes)
def object_detection_proto_to_params(od_config):
@@ -921,6 +946,13 @@ def object_center_proto_to_params(oc_config):
keypoint_weights_for_center = []
if oc_config.keypoint_weights_for_center:
keypoint_weights_for_center = list(oc_config.keypoint_weights_for_center)
if oc_config.center_head_params:
center_head_num_filters = list(oc_config.center_head_params.num_filters)
center_head_kernel_sizes = list(oc_config.center_head_params.kernel_sizes)
else:
center_head_num_filters = [256]
center_head_kernel_sizes = [3]
return center_net_meta_arch.ObjectCenterParams(
classification_loss=classification_loss,
object_center_loss_weight=oc_config.object_center_loss_weight,
@@ -928,7 +960,9 @@ def object_center_proto_to_params(oc_config):
min_box_overlap_iou=oc_config.min_box_overlap_iou,
max_box_predictions=oc_config.max_box_predictions,
use_labeled_classes=oc_config.use_labeled_classes,
keypoint_weights_for_center=keypoint_weights_for_center)
keypoint_weights_for_center=keypoint_weights_for_center,
center_head_num_filters=center_head_num_filters,
center_head_kernel_sizes=center_head_kernel_sizes)
def mask_proto_to_params(mask_config):
...
@@ -120,6 +120,12 @@ class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest):
predict_depth: true
per_keypoint_depth: true
keypoint_depth_loss_weight: 0.3
heatmap_head_params {
num_filters: 64
num_filters: 32
kernel_sizes: 5
kernel_sizes: 3
}
""" """
config = text_format.Merge(task_proto_txt, config = text_format.Merge(task_proto_txt,
center_net_pb2.CenterNet.KeypointEstimation()) center_net_pb2.CenterNet.KeypointEstimation())
...@@ -137,6 +143,12 @@ class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest): ...@@ -137,6 +143,12 @@ class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest):
beta: 4.0 beta: 4.0
} }
} }
center_head_params {
num_filters: 64
num_filters: 32
kernel_sizes: 5
kernel_sizes: 3
}
""" """
return text_format.Merge(proto_txt, return text_format.Merge(proto_txt,
center_net_pb2.CenterNet.ObjectCenterParams()) center_net_pb2.CenterNet.ObjectCenterParams())
...@@ -257,6 +269,8 @@ class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest): ...@@ -257,6 +269,8 @@ class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest):
self.assertAlmostEqual( self.assertAlmostEqual(
model._center_params.heatmap_bias_init, 3.14, places=4) model._center_params.heatmap_bias_init, 3.14, places=4)
self.assertEqual(model._center_params.max_box_predictions, 15) self.assertEqual(model._center_params.max_box_predictions, 15)
self.assertEqual(model._center_params.center_head_num_filters, [64, 32])
self.assertEqual(model._center_params.center_head_kernel_sizes, [5, 3])
# Check object detection related parameters.
self.assertAlmostEqual(model._od_params.offset_loss_weight, 0.1)
@@ -291,6 +305,12 @@ class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest):
self.assertEqual(kp_params.predict_depth, True)
self.assertEqual(kp_params.per_keypoint_depth, True)
self.assertAlmostEqual(kp_params.keypoint_depth_loss_weight, 0.3)
# Set by the config.
self.assertEqual(kp_params.heatmap_head_num_filters, [64, 32])
self.assertEqual(kp_params.heatmap_head_kernel_sizes, [5, 3])
# Default values:
self.assertEqual(kp_params.offset_head_num_filters, [256])
self.assertEqual(kp_params.offset_head_kernel_sizes, [3])
# Check mask related parameters.
self.assertAlmostEqual(model._mask_params.task_loss_weight, 0.7)
...
@@ -137,7 +137,7 @@ class CenterNetFeatureExtractor(tf.keras.Model):
pass
def make_prediction_net(num_out_channels, kernel_size=3, num_filters=256,
def make_prediction_net(num_out_channels, kernel_sizes=(3), num_filters=(256),
bias_fill=None, use_depthwise=False, name=None):
"""Creates a network to predict the given number of output channels.
@@ -146,8 +146,13 @@ def make_prediction_net(num_out_channels, kernel_size=3, num_filters=256,
Args:
num_out_channels: Number of output channels.
kernel_size: The size of the conv kernel in the intermediate layer
num_filters: The number of filters in the intermediate conv layer.
kernel_sizes: A list representing the sizes of the conv kernel in the
intermediate layer. Note that the length of the list indicates the number
of intermediate conv layers and it must be the same as the length of
num_filters.
num_filters: A list representing the number of filters in the intermediate
conv layer. Note that the length of the list indicates the number of
intermediate conv layers.
bias_fill: If not None, is used to initialize the bias in the final conv
layer.
use_depthwise: If true, use SeparableConv2D to construct the Sequential
@@ -159,6 +164,10 @@ def make_prediction_net(num_out_channels, kernel_size=3, num_filters=256,
[batch_size, height, width, num_in_channels] returns an output
of size [batch_size, height, width, num_out_channels]
"""
if isinstance(kernel_sizes, int) and isinstance(num_filters, int):
kernel_sizes = [kernel_sizes]
num_filters = [num_filters]
assert len(kernel_sizes) == len(num_filters)
if use_depthwise:
conv_fn = tf.keras.layers.SeparableConv2D
else:
@@ -175,16 +184,18 @@ def make_prediction_net(num_out_channels, kernel_size=3, num_filters=256,
if bias_fill is not None:
out_conv.bias_initializer = tf.keras.initializers.constant(bias_fill)
net = tf.keras.Sequential([
conv_fn(
num_filters,
kernel_size=kernel_size,
padding='same',
name='conv2' if tf_version.is_tf1() else None),
tf.keras.layers.ReLU(), out_conv
],
name=name)
layers = []
for idx, (kernel_size,
num_filter) in enumerate(zip(kernel_sizes, num_filters)):
layers.append(
conv_fn(
num_filter,
kernel_size=kernel_size,
padding='same',
name='conv2_%d' % idx if tf_version.is_tf1() else None))
layers.append(tf.keras.layers.ReLU())
layers.append(out_conv)
net = tf.keras.Sequential(layers, name=name)
return net
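As a usage sketch (assuming TF2 and an installed object_detection package; shapes are illustrative, not from the commit), the reworked make_prediction_net stacks one conv + ReLU pair per entry of the kernel_sizes/num_filters lists before the final output convolution:

# Sketch only: builds a two-layer head with the new list-valued arguments and
# runs it on a dummy feature map.
import tensorflow as tf
from object_detection.meta_architectures import center_net_meta_arch

head = center_net_meta_arch.make_prediction_net(
    num_out_channels=17,     # e.g. one heatmap channel per keypoint
    kernel_sizes=[5, 3],     # two intermediate convs: 5x5 then 3x3
    num_filters=[64, 32],
    bias_fill=-2.19)         # example heatmap bias init
features = tf.zeros([2, 128, 128, 64])
print(head(features).shape)  # expected: (2, 128, 128, 17)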
@@ -1687,7 +1698,10 @@ class KeypointEstimationParams(
'offset_peak_radius', 'per_keypoint_offset', 'predict_depth',
'per_keypoint_depth', 'keypoint_depth_loss_weight',
'score_distance_offset', 'clip_out_of_frame_keypoints',
'rescore_instances'
'rescore_instances', 'heatmap_head_num_filters',
'heatmap_head_kernel_sizes', 'offset_head_num_filters',
'offset_head_kernel_sizes', 'regress_head_num_filters',
'regress_head_kernel_sizes'
])):
"""Namedtuple to host object detection related parameters.
@@ -1726,7 +1740,13 @@ class KeypointEstimationParams(
keypoint_depth_loss_weight=1.0,
score_distance_offset=1e-6,
clip_out_of_frame_keypoints=False,
rescore_instances=False):
rescore_instances=False,
heatmap_head_num_filters=(256),
heatmap_head_kernel_sizes=(3),
offset_head_num_filters=(256),
offset_head_kernel_sizes=(3),
regress_head_num_filters=(256),
regress_head_kernel_sizes=(3)):
"""Constructor with default values for KeypointEstimationParams. """Constructor with default values for KeypointEstimationParams.
Args: Args:
...@@ -1806,6 +1826,18 @@ class KeypointEstimationParams( ...@@ -1806,6 +1826,18 @@ class KeypointEstimationParams(
that are clipped have scores set to 0.0. that are clipped have scores set to 0.0.
rescore_instances: Whether to rescore instances based on a combination of rescore_instances: Whether to rescore instances based on a combination of
detection score and keypoint scores. detection score and keypoint scores.
heatmap_head_num_filters: the number of filters in each convolutional layer
of the keypoint heatmap prediction head.
heatmap_head_kernel_sizes: the kernel sizes of the convolutional layers in
the keypoint heatmap prediction head.
offset_head_num_filters: the number of filters in each convolutional layer
of the keypoint offset prediction head.
offset_head_kernel_sizes: the kernel sizes of the convolutional layers in
the keypoint offset prediction head.
regress_head_num_filters: the number of filters in each convolutional layer
of the keypoint regression prediction head.
regress_head_kernel_sizes: the kernel sizes of the convolutional layers in
the keypoint regression prediction head.
Returns:
An initialized KeypointEstimationParams namedtuple.
@@ -1820,14 +1852,18 @@ class KeypointEstimationParams(
candidate_search_scale, candidate_ranking_mode, offset_peak_radius,
per_keypoint_offset, predict_depth, per_keypoint_depth,
keypoint_depth_loss_weight, score_distance_offset,
clip_out_of_frame_keypoints, rescore_instances)
clip_out_of_frame_keypoints, rescore_instances,
heatmap_head_num_filters, heatmap_head_kernel_sizes,
offset_head_num_filters, offset_head_kernel_sizes,
regress_head_num_filters, regress_head_kernel_sizes)
class ObjectCenterParams(
collections.namedtuple('ObjectCenterParams', [
'classification_loss', 'object_center_loss_weight', 'heatmap_bias_init',
'min_box_overlap_iou', 'max_box_predictions', 'use_labeled_classes',
'keypoint_weights_for_center'
'keypoint_weights_for_center', 'center_head_num_filters',
'center_head_kernel_sizes'
])):
"""Namedtuple to store object center prediction related parameters."""
@@ -1840,7 +1876,9 @@ class ObjectCenterParams(
min_box_overlap_iou=0.7,
max_box_predictions=100,
use_labeled_classes=False,
keypoint_weights_for_center=None):
keypoint_weights_for_center=None,
center_head_num_filters=(256),
center_head_kernel_sizes=(3)):
"""Constructor with default values for ObjectCenterParams. """Constructor with default values for ObjectCenterParams.
Args: Args:
...@@ -1861,7 +1899,10 @@ class ObjectCenterParams( ...@@ -1861,7 +1899,10 @@ class ObjectCenterParams(
center is calculated by the weighted mean of the keypoint locations. If center is calculated by the weighted mean of the keypoint locations. If
not provided, the object center is determined by the center of the not provided, the object center is determined by the center of the
bounding box (default behavior). bounding box (default behavior).
center_head_num_filters: the number of filters in each convolutional layer
of the object center prediction head.
center_head_kernel_sizes: the kernel sizes of the convolutional layers in
the object center prediction head.
Returns:
An initialized ObjectCenterParams namedtuple.
"""
@@ -1869,7 +1910,8 @@ class ObjectCenterParams(
cls).__new__(cls, classification_loss,
object_center_loss_weight, heatmap_bias_init,
min_box_overlap_iou, max_box_predictions,
use_labeled_classes, keypoint_weights_for_center)
use_labeled_classes, keypoint_weights_for_center,
center_head_num_filters, center_head_kernel_sizes)
class MaskParams(
@@ -2194,14 +2236,14 @@ class CenterNetMetaArch(model.DetectionModel):
return self._batched_prediction_tensor_names
def _make_prediction_net_list(self, num_feature_outputs, num_out_channels,
kernel_size=3, num_filters=256, bias_fill=None,
name=None):
kernel_sizes=(3), num_filters=(256),
bias_fill=None, name=None):
prediction_net_list = []
for i in range(num_feature_outputs):
prediction_net_list.append(
make_prediction_net(
num_out_channels,
kernel_size=kernel_size,
kernel_sizes=kernel_sizes,
num_filters=num_filters,
bias_fill=bias_fill,
use_depthwise=self._use_depthwise,
@@ -2229,7 +2271,11 @@ class CenterNetMetaArch(model.DetectionModel):
"""
prediction_heads = {}
prediction_heads[OBJECT_CENTER] = self._make_prediction_net_list(
num_feature_outputs, num_classes, bias_fill=class_prediction_bias_init,
num_feature_outputs,
num_classes,
kernel_sizes=self._center_params.center_head_kernel_sizes,
num_filters=self._center_params.center_head_num_filters,
bias_fill=class_prediction_bias_init,
name='center')
if self._od_params is not None:
@@ -2245,12 +2291,16 @@ class CenterNetMetaArch(model.DetectionModel):
task_name, KEYPOINT_HEATMAP)] = self._make_prediction_net_list(
num_feature_outputs,
num_keypoints,
kernel_sizes=kp_params.heatmap_head_kernel_sizes,
num_filters=kp_params.heatmap_head_num_filters,
bias_fill=kp_params.heatmap_bias_init,
name='kpt_heatmap')
prediction_heads[get_keypoint_name(
task_name, KEYPOINT_REGRESSION)] = self._make_prediction_net_list(
num_feature_outputs,
NUM_OFFSET_CHANNELS * num_keypoints,
kernel_sizes=kp_params.regress_head_kernel_sizes,
num_filters=kp_params.regress_head_num_filters,
name='kpt_regress')
if kp_params.per_keypoint_offset:
@@ -2258,11 +2308,17 @@ class CenterNetMetaArch(model.DetectionModel):
task_name, KEYPOINT_OFFSET)] = self._make_prediction_net_list(
num_feature_outputs,
NUM_OFFSET_CHANNELS * num_keypoints,
kernel_sizes=kp_params.offset_head_kernel_sizes,
num_filters=kp_params.offset_head_num_filters,
name='kpt_offset')
else:
prediction_heads[get_keypoint_name(
task_name, KEYPOINT_OFFSET)] = self._make_prediction_net_list(
num_feature_outputs, NUM_OFFSET_CHANNELS, name='kpt_offset')
num_feature_outputs,
NUM_OFFSET_CHANNELS,
kernel_sizes=kp_params.offset_head_kernel_sizes,
num_filters=kp_params.offset_head_num_filters,
name='kpt_offset')
if kp_params.predict_depth:
num_depth_channel = (
...
@@ -1482,7 +1482,9 @@ def get_fake_center_params(max_box_predictions=5):
object_center_loss_weight=1.0,
min_box_overlap_iou=1.0,
max_box_predictions=max_box_predictions,
use_labeled_classes=False)
use_labeled_classes=False,
center_head_num_filters=[128],
center_head_kernel_sizes=[5])
def get_fake_od_params():
...
@@ -30,6 +30,23 @@ message CenterNet {
// TODO(b/170989061) When bug is fixed, make this the default behavior.
optional bool compute_heatmap_sparse = 15 [default = false];
// Parameters to determine the model architecture/layers of the prediction
// heads.
message PredictionHeadParams {
// The two fields, num_filters and kernel_sizes, correspond to the parameters
// of the convolutional layers used by the prediction head. If provided, the
// lengths of the two repeated fields need to be the same and represent the
// number of convolutional layers.
// Corresponds to the "filters" argument in tf.keras.layers.Conv2D. If not
// provided, the default value [256] will be used.
repeated int32 num_filters = 1;
// Corresponds to the "kernel_size" argument in tf.keras.layers.Conv2D. If
// not provided, the default value [3] will be used.
repeated int32 kernel_sizes = 2;
}
// Parameters which are related to object detection task.
message ObjectDetection {
// The original fields are moved to ObjectCenterParams or deleted.
@@ -81,6 +98,10 @@ message CenterNet {
// object center is determined by the bounding box groundtruth annotations
// (default behavior).
repeated float keypoint_weights_for_center = 7;
// Parameters to determine the architecture of the object center prediction
// head.
optional PredictionHeadParams center_head_params = 8;
}
optional ObjectCenterParams object_center_params = 5;
@@ -207,6 +228,18 @@ message CenterNet {
// where o is the object score, s_i is the score for keypoint i, and k is
// the number of keypoints for that class.
optional bool rescore_instances = 24 [default = false];
// Parameters to determine the architecture of the keypoint heatmap
// prediction head.
optional PredictionHeadParams heatmap_head_params = 25;
// Parameters to determine the architecture of the keypoint offset
// prediction head.
optional PredictionHeadParams offset_head_params = 26;
// Parameters to determine the architecture of the keypoint regression
// prediction head.
optional PredictionHeadParams regress_head_params = 27;
}
repeated KeypointEstimation keypoint_estimation_task = 7;
...
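To close the loop on the defaults noted in the proto comments, here is a short sketch (again not part of the commit) of the omitted-message case, using the same HasField check that keypoint_proto_to_params applies in the first hunk:

# Sketch only: assumes the object_detection protos are importable.
from object_detection.protos import center_net_pb2

kp_config = center_net_pb2.CenterNet.KeypointEstimation()
# offset_head_params was never set, so the builder keeps the defaults.
if kp_config.HasField('offset_head_params'):
  offset_head_num_filters = list(kp_config.offset_head_params.num_filters)
  offset_head_kernel_sizes = list(kp_config.offset_head_params.kernel_sizes)
else:
  offset_head_num_filters = [256]  # default documented in the proto
  offset_head_kernel_sizes = [3]
print(offset_head_num_filters, offset_head_kernel_sizes)  # [256] [3]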