Commit fe816dc2 authored by Yu-hui Chen's avatar Yu-hui Chen Committed by TF Object Detection Team
Browse files

Updated the target assigner and CenterNet meta architecture to support the
use case in which keypoint locations determine the object center location
and heatmap.

PiperOrigin-RevId: 363522414
parent f84103f8
......@@ -915,13 +915,17 @@ def object_center_proto_to_params(oc_config):
losses_pb2.WeightedL2LocalizationLoss())
loss.classification_loss.CopyFrom(oc_config.classification_loss)
classification_loss, _, _, _, _, _, _ = (losses_builder.build(loss))
keypoint_weights_for_center = []
if oc_config.keypoint_weights_for_center:
keypoint_weights_for_center = list(oc_config.keypoint_weights_for_center)
return center_net_meta_arch.ObjectCenterParams(
classification_loss=classification_loss,
object_center_loss_weight=oc_config.object_center_loss_weight,
heatmap_bias_init=oc_config.heatmap_bias_init,
min_box_overlap_iou=oc_config.min_box_overlap_iou,
max_box_predictions=oc_config.max_box_predictions,
use_labeled_classes=oc_config.use_labeled_classes)
use_labeled_classes=oc_config.use_labeled_classes,
keypoint_weights_for_center=keypoint_weights_for_center)
def mask_proto_to_params(mask_config):
......
......@@ -141,6 +141,26 @@ class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest):
return text_format.Merge(proto_txt,
center_net_pb2.CenterNet.ObjectCenterParams())
def get_fake_object_center_from_keypoints_proto(self):
  """Returns a fake ObjectCenterParams proto with keypoint center weights."""
  config_text = """
    object_center_loss_weight: 0.5
    heatmap_bias_init: 3.14
    min_box_overlap_iou: 0.2
    max_box_predictions: 15
    classification_loss {
      penalty_reduced_logistic_focal_loss {
        alpha: 3.0
        beta: 4.0
      }
    }
    keypoint_weights_for_center: 1.0
    keypoint_weights_for_center: 0.0
    keypoint_weights_for_center: 1.0
    keypoint_weights_for_center: 0.0
  """
  # Parse the text proto into an empty ObjectCenterParams message.
  params = center_net_pb2.CenterNet.ObjectCenterParams()
  return text_format.Merge(config_text, params)
def get_fake_object_detection_proto(self):
proto_txt = """
task_loss_weight: 0.5
......@@ -308,6 +328,50 @@ class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest):
self.assertIsInstance(backbone, hourglass_network.HourglassNetwork)
self.assertTrue(backbone.num_hourglasses, 1)
def test_create_center_net_model_from_keypoints(self):
  """Test building a CenterNet model from proto txt.

  Configures the model with an object-center task driven by keypoints (no
  object-detection task) plus a single keypoint-estimation task, builds it,
  and checks the resulting parameters.
  """
  proto_txt = """
    center_net {
      num_classes: 10
      feature_extractor {
        type: "hourglass_52"
        channel_stds: [4, 5, 6]
        bgr_ordering: true
      }
      image_resizer {
        keep_aspect_ratio_resizer {
          min_dimension: 512
          max_dimension: 512
          pad_to_max_dimension: true
        }
      }
    }
  """
  # Set up the configuration proto.
  config = text_format.Merge(proto_txt, model_pb2.DetectionModel())
  # Only add object center and keypoint estimation configs here.
  config.center_net.object_center_params.CopyFrom(
      self.get_fake_object_center_from_keypoints_proto())
  config.center_net.keypoint_estimation_task.append(
      self.get_fake_keypoint_proto())
  config.center_net.keypoint_label_map_path = (
      self.get_fake_label_map_file_path())

  # Build the model from the configuration.
  model = model_builder.build(config, is_training=True)

  # Check object center related parameters.
  self.assertEqual(model._num_classes, 10)
  self.assertEqual(model._center_params.keypoint_weights_for_center,
                   [1.0, 0.0, 1.0, 0.0])

  # Check keypoint estimation related parameters.
  kp_params = model._kp_params_dict['human_pose']
  self.assertAlmostEqual(kp_params.task_loss_weight, 0.9)
  self.assertEqual(kp_params.keypoint_indices, [0, 1, 2, 3])
  self.assertEqual(kp_params.keypoint_labels,
                   ['nose', 'left_shoulder', 'right_shoulder', 'hip'])
if __name__ == '__main__':
tf.test.main()
......@@ -837,10 +837,82 @@ def _compute_std_dev_from_box_size(boxes_height, boxes_width, min_overlap):
return sigma
def _preprocess_keypoints_and_weights(out_height, out_width, keypoints,
                                      class_onehot, class_weights,
                                      keypoint_weights, class_id,
                                      keypoint_indices):
  """Preprocesses the keypoints and the corresponding keypoint weights.

  Common preprocessing shared by the keypoint-related target assigners:
  1) Keeps only the keypoints listed in `keypoint_indices`, replaces NaN
     coordinates with zeros, and converts the result to absolute coordinates.
  2) Builds the per-keypoint weights from (a) the class of the instance,
     (b) the NaN mask of the original coordinates, and (c) the optionally
     provided keypoint weights.

  Args:
    out_height: An integer or an integer tensor indicating the output height
      of the model.
    out_width: An integer or an integer tensor indicating the output width of
      the model.
    keypoints: A float tensor of shape [num_instances, num_total_keypoints, 2]
      representing the original keypoint groundtruth coordinates.
    class_onehot: A float tensor of shape [num_instances, num_classes]
      containing the class targets with the 0th index assumed to map to the
      first non-background class.
    class_weights: A float tensor of shape [num_instances] containing weights
      for groundtruth instances.
    keypoint_weights: A float tensor of shape
      [num_instances, num_total_keypoints] representing the weights of each
      keypoint, or None to weight all valid keypoints equally.
    class_id: int, the ID of the class (0-indexed) that contains the target
      keypoints to consider in this task.
    keypoint_indices: A list of integers representing the indices of the
      keypoints to be considered in this task.

  Returns:
    A tuple of two tensors:
      keypoint_absolute: A float tensor of shape
        [num_instances, num_keypoints, 2] which is the selected and updated
        keypoint coordinates.
      keypoint_weights: A float tensor of shape [num_instances, num_keypoints]
        representing the updated weight of each keypoint.
  """
  # Restrict the keypoints to the requested class and indices; `valid_mask`
  # flags the entries whose original coordinates were not NaN.
  valid_mask, selected_keypoints = ta_utils.get_valid_keypoint_mask_for_class(
      keypoint_coordinates=keypoints,
      class_id=class_id,
      class_onehot=class_onehot,
      class_weights=class_weights,
      keypoint_indices=keypoint_indices)
  # Convert to the absolute coordinate system. Resulting shape:
  # [num_instances, num_keypoints, 2].
  absolute_keypoints = keypoint_ops.to_absolute_coordinates(
      selected_keypoints, out_height, out_width)
  if keypoint_weights is None:
    # Default: weight every selected keypoint equally.
    selected_weights = tf.ones_like(selected_keypoints[:, :, 0])
  else:
    selected_weights = tf.gather(
        keypoint_weights, indices=keypoint_indices, axis=1)
  # Zero out the weights of invalid keypoints.
  return absolute_keypoints, selected_weights * valid_mask
class CenterNetCenterHeatmapTargetAssigner(object):
"""Wrapper to compute the object center heatmap."""
def __init__(self, stride, min_overlap=0.7, compute_heatmap_sparse=False):
def __init__(self,
stride,
min_overlap=0.7,
compute_heatmap_sparse=False,
keypoint_class_id=None,
keypoint_indices=None,
keypoint_weights_for_center=None):
"""Initializes the target assigner.
Args:
......@@ -851,11 +923,25 @@ class CenterNetCenterHeatmapTargetAssigner(object):
version of the Op that computes the heatmap. The sparse version scales
better with number of classes, but in some cases is known to cause
OOM error. See (b/170989061).
keypoint_class_id: int, the ID of the class (0-indexed) that contains the
target keypoints to consider in this task.
keypoint_indices: A list of integers representing the indices of the
keypoints to be considered in this task. This is used to retrieve the
subset of the keypoints from gt_keypoints that should be considered in
this task.
keypoint_weights_for_center: The keypoint weights used for calculating the
location of object center. The number of weights need to be the same as
the number of keypoints. The object center is calculated by the weighted
mean of the keypoint locations. If not provided, the object center is
determined by the center of the bounding box (default behavior).
"""
self._stride = stride
self._min_overlap = min_overlap
self._compute_heatmap_sparse = compute_heatmap_sparse
self._keypoint_class_id = keypoint_class_id
self._keypoint_indices = keypoint_indices
self._keypoint_weights_for_center = keypoint_weights_for_center
def assign_center_targets_from_boxes(self,
height,
......@@ -927,6 +1013,145 @@ class CenterNetCenterHeatmapTargetAssigner(object):
# Return the stacked heatmaps over the batch.
return tf.stack(heatmaps, axis=0)
def assign_center_targets_from_keypoints(self,
                                         height,
                                         width,
                                         gt_classes_list,
                                         gt_keypoints_list,
                                         gt_weights_list=None,
                                         gt_keypoints_weights_list=None):
  """Computes the object center heatmap target using keypoint locations.

  The center of each instance is the weighted mean of its (valid) keypoint
  locations, using the `keypoint_weights_for_center` provided at
  construction time; the Gaussian sigma is derived from the average
  keypoint-to-center distance.

  Args:
    height: int, height of input to the model. This is used to
      determine the height of the output.
    width: int, width of the input to the model. This is used to
      determine the width of the output.
    gt_classes_list: A list of float tensors with shape [num_instances,
      num_classes] representing the one-hot encoded class labels for each
      instance.
    gt_keypoints_list: A list of float tensors with shape
      [num_instances, num_total_keypoints, 2] representing the groundtruth
      keypoint coordinates for each sample in the batch, in normalized
      coordinates.
    gt_weights_list: A list of float tensors with shape [num_instances]
      representing the weight of each groundtruth instance.
    gt_keypoints_weights_list: [Optional] a list of tf.float32 tensors of
      shape [num_instances, num_total_keypoints] representing the weights of
      each keypoint. If not provided, then all not NaN keypoints will be
      equally weighted.

  Returns:
    heatmap: A Tensor of size [batch_size, output_height, output_width,
      num_classes] representing the per class center heatmap. output_height
      and output_width are computed by dividing the input height and width by
      the stride specified during initialization.
  """
  # The keypoint-based center computation is only valid when all of the
  # keypoint-related constructor arguments were provided.
  assert (self._keypoint_weights_for_center is not None and
          self._keypoint_class_id is not None and
          self._keypoint_indices is not None), (
              'keypoint_class_id, keypoint_indices and '
              'keypoint_weights_for_center must all be set to assign center '
              'targets from keypoints.')
  out_height = tf.cast(height // self._stride, tf.float32)
  out_width = tf.cast(width // self._stride, tf.float32)
  # Compute the yx-grid to be used to generate the heatmap. Each returned
  # tensor has shape of [out_height, out_width].
  (y_grid, x_grid) = ta_utils.image_shape_to_grids(out_height, out_width)

  heatmaps = []
  if gt_weights_list is None:
    gt_weights_list = [None] * len(gt_classes_list)
  if gt_keypoints_weights_list is None:
    gt_keypoints_weights_list = [None] * len(gt_keypoints_list)

  for keypoints, classes, kp_weights, weights in zip(
      gt_keypoints_list, gt_classes_list, gt_keypoints_weights_list,
      gt_weights_list):
    keypoints_absolute, kp_weights = _preprocess_keypoints_and_weights(
        out_height=out_height,
        out_width=out_width,
        keypoints=keypoints,
        class_onehot=classes,
        class_weights=weights,
        keypoint_weights=kp_weights,
        class_id=self._keypoint_class_id,
        keypoint_indices=self._keypoint_indices)

    # Update the keypoint weights by the specified per-keypoint center
    # weights.
    kp_loc_weights = tf.constant(
        self._keypoint_weights_for_center, dtype=tf.float32)
    updated_kp_weights = kp_weights * kp_loc_weights[tf.newaxis, :]

    # Obtain the sum of the weights for each instance.
    # instance_weight_sum has shape: [num_instance].
    instance_weight_sum = tf.reduce_sum(updated_kp_weights, axis=1)

    # Weight the keypoint coordinates by updated_kp_weights.
    # weighted_keypoints has shape: [num_instance, num_keypoints, 2].
    weighted_keypoints = keypoints_absolute * tf.expand_dims(
        updated_kp_weights, axis=2)

    # Compute the mean of the keypoint coordinates over the weighted
    # keypoints.
    # keypoint_mean has shape: [num_instance, 2].
    keypoint_mean = tf.math.divide(
        tf.reduce_sum(weighted_keypoints, axis=1),
        tf.expand_dims(instance_weight_sum, axis=-1))

    # Replace the NaN values (due to divided by zeros in the above operation)
    # by 0.0 where the sum of instance weight is zero.
    # keypoint_mean has shape: [num_instance, 2].
    keypoint_mean = tf.where(
        tf.stack([instance_weight_sum, instance_weight_sum], axis=1) > 0.0,
        keypoint_mean, tf.zeros_like(keypoint_mean))

    # Compute the distance from each keypoint to the mean location using
    # broadcasting and weighted by updated_kp_weights.
    # keypoint_dist has shape: [num_instance, num_keypoints].
    keypoint_mean = tf.expand_dims(keypoint_mean, axis=1)
    keypoint_dist = tf.math.sqrt(
        tf.reduce_sum(
            tf.math.square(keypoints_absolute - keypoint_mean), axis=2))
    keypoint_dist = keypoint_dist * updated_kp_weights

    # Compute the average of the distances from each keypoint to the mean
    # location and update the average value by zero when the instance weight
    # is zero.
    # avg_radius has shape: [num_instance].
    avg_radius = tf.math.divide(
        tf.reduce_sum(keypoint_dist, axis=1), instance_weight_sum)
    avg_radius = tf.where(
        instance_weight_sum > 0.0, avg_radius, tf.zeros_like(avg_radius))

    # Update the class instance weight. If the instance doesn't contain
    # enough valid keypoint values (i.e. instance_weight_sum == 0.0), then
    # set the instance weight to zero.
    # updated_class_weights has shape: [num_instance].
    updated_class_weights = tf.where(
        instance_weight_sum > 0.0, weights, tf.zeros_like(weights))

    # Compute the sigma from average distance. We use 2 * average distance
    # to approximate the width/height of the bounding box.
    # sigma has shape: [num_instances].
    sigma = _compute_std_dev_from_box_size(2 * avg_radius, 2 * avg_radius,
                                           self._min_overlap)

    # Apply the Gaussian kernel to the center coordinates. Returned heatmap
    # has shape of [out_height, out_width, num_classes].
    heatmap = ta_utils.coordinates_to_heatmap(
        y_grid=y_grid,
        x_grid=x_grid,
        y_coordinates=keypoint_mean[:, 0, 0],
        x_coordinates=keypoint_mean[:, 0, 1],
        sigma=sigma,
        channel_onehot=classes,
        channel_weights=updated_class_weights,
        sparse=self._compute_heatmap_sparse)
    heatmaps.append(heatmap)

  # Return the stacked heatmaps over the batch.
  return tf.stack(heatmaps, axis=0)
class CenterNetBoxTargetAssigner(object):
"""Wrapper to compute target tensors for the object detection task.
......@@ -1126,65 +1351,6 @@ class CenterNetKeypointTargetAssigner(object):
assert len(keypoint_indices) == len(keypoint_std_dev)
self._keypoint_std_dev = keypoint_std_dev
def _preprocess_keypoints_and_weights(self, out_height, out_width, keypoints,
                                      class_onehot, class_weights,
                                      keypoint_weights):
  """Preprocesses the keypoints and the corresponding keypoint weights.

  This function performs several common steps to preprocess the keypoints and
  keypoint weights features, including:
  1) Select the subset of keypoints based on the keypoint indices, fill the
     keypoint NaN values with zeros and convert to absolute coordinates.
  2) Generate the weights of the keypoint using the following information:
     a. The class of the instance.
     b. The NaN value of the keypoint coordinates.
     c. The provided keypoint weights.

  Args:
    out_height: An integer or an integer tensor indicating the output height
      of the model.
    out_width: An integer or an integer tensor indicating the output width of
      the model.
    keypoints: A float tensor of shape [num_instances, num_total_keypoints, 2]
      representing the original keypoint groundtruth coordinates.
    class_onehot: A float tensor of shape [num_instances, num_classes]
      containing the class targets with the 0th index assumed to map to the
      first non-background class.
    class_weights: A float tensor of shape [num_instances] containing weights
      for groundtruth instances.
    keypoint_weights: A float tensor of shape
      [num_instances, num_total_keypoints] representing the weights of each
      keypoints.

  Returns:
    A tuple of two tensors:
      keypoint_absolute: A float tensor of shape
        [num_instances, num_keypoints, 2] which is the selected and updated
        keypoint coordinates.
      keypoint_weights: A float tensor of shape [num_instances, num_keypoints]
        representing the updated weight of each keypoint.
  """
  # NOTE(review): this duplicates the module-level
  # _preprocess_keypoints_and_weights helper; consider delegating to it with
  # class_id=self._class_id and keypoint_indices=self._keypoint_indices.
  # Select the targets keypoints by their type ids and generate the mask
  # of valid elements.
  valid_mask, keypoints = ta_utils.get_valid_keypoint_mask_for_class(
      keypoint_coordinates=keypoints,
      class_id=self._class_id,
      class_onehot=class_onehot,
      class_weights=class_weights,
      keypoint_indices=self._keypoint_indices)
  # Keypoint coordinates in absolute coordinate system.
  # The shape of the tensors: [num_instances, num_keypoints, 2].
  keypoints_absolute = keypoint_ops.to_absolute_coordinates(
      keypoints, out_height, out_width)
  # Assign default weights for the keypoints.
  if keypoint_weights is None:
    # All selected keypoints are equally weighted; NaN keypoints are masked
    # out below via valid_mask.
    keypoint_weights = tf.ones_like(keypoints[:, :, 0])
  else:
    keypoint_weights = tf.gather(
        keypoint_weights, indices=self._keypoint_indices, axis=1)
  keypoint_weights = keypoint_weights * valid_mask
  return keypoints_absolute, keypoint_weights
def assign_keypoint_heatmap_targets(self,
height,
width,
......@@ -1245,13 +1411,15 @@ class CenterNetKeypointTargetAssigner(object):
for keypoints, classes, kp_weights, weights, boxes in zip(
gt_keypoints_list, gt_classes_list, gt_keypoints_weights_list,
gt_weights_list, gt_boxes_list):
keypoints_absolute, kp_weights = self._preprocess_keypoints_and_weights(
keypoints_absolute, kp_weights = _preprocess_keypoints_and_weights(
out_height=out_height,
out_width=out_width,
keypoints=keypoints,
class_onehot=classes,
class_weights=weights,
keypoint_weights=kp_weights)
keypoint_weights=kp_weights,
class_id=self._class_id,
keypoint_indices=self._keypoint_indices)
num_instances, num_keypoints, _ = (
shape_utils.combined_static_and_dynamic_shape(keypoints_absolute))
......@@ -1399,13 +1567,15 @@ class CenterNetKeypointTargetAssigner(object):
for i, (keypoints, classes, kp_weights, weights) in enumerate(
zip(gt_keypoints_list, gt_classes_list, gt_keypoints_weights_list,
gt_weights_list)):
keypoints_absolute, kp_weights = self._preprocess_keypoints_and_weights(
keypoints_absolute, kp_weights = _preprocess_keypoints_and_weights(
out_height=height // self._stride,
out_width=width // self._stride,
keypoints=keypoints,
class_onehot=classes,
class_weights=weights,
keypoint_weights=kp_weights)
keypoint_weights=kp_weights,
class_id=self._class_id,
keypoint_indices=self._keypoint_indices)
num_instances, num_keypoints, _ = (
shape_utils.combined_static_and_dynamic_shape(keypoints_absolute))
......@@ -1532,13 +1702,15 @@ class CenterNetKeypointTargetAssigner(object):
zip(gt_keypoints_list, gt_classes_list,
gt_keypoints_weights_list, gt_weights_list,
gt_keypoint_depths_list, gt_keypoint_depth_weights_list)):
keypoints_absolute, kp_weights = self._preprocess_keypoints_and_weights(
keypoints_absolute, kp_weights = _preprocess_keypoints_and_weights(
out_height=height // self._stride,
out_width=width // self._stride,
keypoints=keypoints,
class_onehot=classes,
class_weights=weights,
keypoint_weights=kp_weights)
keypoint_weights=kp_weights,
class_id=self._class_id,
keypoint_indices=self._keypoint_indices)
num_instances, num_keypoints, _ = (
shape_utils.combined_static_and_dynamic_shape(keypoints_absolute))
......@@ -1702,13 +1874,15 @@ class CenterNetKeypointTargetAssigner(object):
for i, (keypoints, classes, boxes, kp_weights, weights) in enumerate(
zip(gt_keypoints_list, gt_classes_list,
gt_boxes_list, gt_keypoints_weights_list, gt_weights_list)):
keypoints_absolute, kp_weights = self._preprocess_keypoints_and_weights(
keypoints_absolute, kp_weights = _preprocess_keypoints_and_weights(
out_height=height // self._stride,
out_width=width // self._stride,
keypoints=keypoints,
class_onehot=classes,
class_weights=weights,
keypoint_weights=kp_weights)
keypoint_weights=kp_weights,
class_id=self._class_id,
keypoint_indices=self._keypoint_indices)
num_instances, num_keypoints, _ = (
shape_utils.combined_static_and_dynamic_shape(keypoints_absolute))
......
......@@ -14,6 +14,7 @@
# ==============================================================================
"""Tests for object_detection.core.target_assigner."""
from absl.testing import parameterized
import numpy as np
import tensorflow.compat.v1 as tf
......@@ -1235,7 +1236,8 @@ def _array_argmax(array):
return np.unravel_index(np.argmax(array), array.shape)
class CenterNetCenterHeatmapTargetAssignerTest(test_case.TestCase):
class CenterNetCenterHeatmapTargetAssignerTest(test_case.TestCase,
parameterized.TestCase):
def setUp(self):
super(CenterNetCenterHeatmapTargetAssignerTest, self).setUp()
......@@ -1263,6 +1265,66 @@ class CenterNetCenterHeatmapTargetAssignerTest(test_case.TestCase):
self.assertEqual((15, 5), _array_argmax(targets[0, :, :, 1]))
self.assertAlmostEqual(1.0, targets[0, 15, 5, 1])
@parameterized.parameters(
    {'keypoint_weights_for_center': [1.0, 1.0, 1.0, 1.0]},
    {'keypoint_weights_for_center': [0.0, 0.0, 1.0, 1.0]},
)
def test_center_location_by_keypoints(self, keypoint_weights_for_center):
  """Test that the centers are at the correct location."""
  # Three instances with four keypoints each; the third instance is disabled
  # via gt_weights_list below.
  kpts_y = [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.0, 0.0, 0.0, 0.0]]
  kpts_x = [[0.5, 0.6, 0.7, 0.8], [0.1, 0.2, 0.3, 0.4], [0.0, 0.0, 0.0, 0.0]]
  gt_keypoints_list = [
      tf.stack([tf.constant(kpts_y), tf.constant(kpts_x)], axis=2)
  ]
  # Per-keypoint weights: the 2nd and 4th keypoints of the last two
  # instances are masked out.
  kpts_weight = [[1.0, 1.0, 1.0, 1.0], [1.0, 0.0, 1.0, 0.0],
                 [1.0, 0.0, 1.0, 0.0]]
  gt_keypoints_weights_list = [tf.constant(kpts_weight)]
  gt_classes_list = [
      tf.one_hot([0, 0, 0], depth=1),
  ]
  gt_weights_list = [tf.constant([1.0, 1.0, 0.0])]

  def graph_fn():
    # Stride of 4: an 80x80 input yields a 20x20 heatmap.
    assigner = targetassigner.CenterNetCenterHeatmapTargetAssigner(
        4,
        keypoint_class_id=0,
        keypoint_indices=[0, 1, 2, 3],
        keypoint_weights_for_center=keypoint_weights_for_center)
    targets = assigner.assign_center_targets_from_keypoints(
        80,
        80,
        gt_classes_list=gt_classes_list,
        gt_keypoints_list=gt_keypoints_list,
        gt_weights_list=gt_weights_list,
        gt_keypoints_weights_list=gt_keypoints_weights_list)
    return targets

  targets = self.execute(graph_fn, [])

  if sum(keypoint_weights_for_center) == 4.0:
    # There should be two peaks at location (5, 13), and (12, 4).
    # (5, 13) = ((0.1 + 0.2 + 0.3 + 0.4) / 4 * 80 / 4,
    #            (0.5 + 0.6 + 0.7 + 0.8) / 4 * 80 / 4)
    # (12, 4) = ((0.5 + 0.7) / 2 * 80 / 4,
    #            (0.1 + 0.3) / 2 * 80 / 4)
    self.assertEqual((5, 13), _array_argmax(targets[0, :, :, 0]))
    self.assertAlmostEqual(1.0, targets[0, 5, 13, 0])
    self.assertEqual((1, 20, 20, 1), targets.shape)
    # Zero out the first peak so argmax locates the second one.
    targets[0, 5, 13, 0] = 0.0
    self.assertEqual((12, 4), _array_argmax(targets[0, :, :, 0]))
    self.assertAlmostEqual(1.0, targets[0, 12, 4, 0])
  else:
    # There should be two peaks at location (7, 15), and (14, 6), since
    # only the last two keypoints contribute to the weighted center.
    # (7, 15) = ((0.3 + 0.4) / 2 * 80 / 4,
    #            (0.7 + 0.8) / 2 * 80 / 4)
    # (14, 6) = (0.7 * 80 / 4, 0.3 * 80 / 4)
    self.assertEqual((7, 15), _array_argmax(targets[0, :, :, 0]))
    self.assertAlmostEqual(1.0, targets[0, 7, 15, 0])
    self.assertEqual((1, 20, 20, 1), targets.shape)
    # Zero out the first peak so argmax locates the second one.
    targets[0, 7, 15, 0] = 0.0
    self.assertEqual((14, 6), _array_argmax(targets[0, :, :, 0]))
    self.assertAlmostEqual(1.0, targets[0, 14, 6, 0])
def test_center_batch_shape(self):
"""Test that the shape of the target for a batch is correct."""
def graph_fn():
......
......@@ -1730,7 +1730,8 @@ class KeypointEstimationParams(
class ObjectCenterParams(
collections.namedtuple('ObjectCenterParams', [
'classification_loss', 'object_center_loss_weight', 'heatmap_bias_init',
'min_box_overlap_iou', 'max_box_predictions', 'use_only_known_classes'
'min_box_overlap_iou', 'max_box_predictions', 'use_labeled_classes',
'keypoint_weights_for_center'
])):
"""Namedtuple to store object center prediction related parameters."""
......@@ -1742,7 +1743,8 @@ class ObjectCenterParams(
heatmap_bias_init=-2.19,
min_box_overlap_iou=0.7,
max_box_predictions=100,
use_labeled_classes=False):
use_labeled_classes=False,
keypoint_weights_for_center=None):
"""Constructor with default values for ObjectCenterParams.
Args:
......@@ -1757,6 +1759,12 @@ class ObjectCenterParams(
computing the class specific center heatmaps.
max_box_predictions: int, the maximum number of boxes to predict.
use_labeled_classes: boolean, compute the loss only labeled classes.
keypoint_weights_for_center: (optional) The keypoint weights used for
calculating the location of object center. If provided, the number of
weights need to be the same as the number of keypoints. The object
center is calculated by the weighted mean of the keypoint locations. If
not provided, the object center is determined by the center of the
bounding box (default behavior).
Returns:
An initialized ObjectCenterParams namedtuple.
......@@ -1765,7 +1773,7 @@ class ObjectCenterParams(
cls).__new__(cls, classification_loss,
object_center_loss_weight, heatmap_bias_init,
min_box_overlap_iou, max_box_predictions,
use_labeled_classes)
use_labeled_classes, keypoint_weights_for_center)
class MaskParams(
......@@ -2224,9 +2232,31 @@ class CenterNetMetaArch(model.DetectionModel):
A dictionary of initialized target assigners for each task.
"""
target_assigners = {}
target_assigners[OBJECT_CENTER] = (
cn_assigner.CenterNetCenterHeatmapTargetAssigner(
stride, min_box_overlap_iou, self._compute_heatmap_sparse))
keypoint_weights_for_center = (
self._center_params.keypoint_weights_for_center)
if not keypoint_weights_for_center:
target_assigners[OBJECT_CENTER] = (
cn_assigner.CenterNetCenterHeatmapTargetAssigner(
stride, min_box_overlap_iou, self._compute_heatmap_sparse))
self._center_from_keypoints = False
else:
# Determining the object center location by keypoint location is only
# supported when there is exactly one keypoint prediction task and no
# object detection task is specified.
assert len(self._kp_params_dict) == 1 and self._od_params is None
kp_params = next(iter(self._kp_params_dict.values()))
# The number of keypoint_weights_for_center needs to be the same as the
# number of keypoints.
assert len(keypoint_weights_for_center) == len(kp_params.keypoint_indices)
target_assigners[OBJECT_CENTER] = (
cn_assigner.CenterNetCenterHeatmapTargetAssigner(
stride,
min_box_overlap_iou,
self._compute_heatmap_sparse,
keypoint_class_id=kp_params.class_id,
keypoint_indices=kp_params.keypoint_indices,
keypoint_weights_for_center=keypoint_weights_for_center))
self._center_from_keypoints = True
if self._od_params is not None:
target_assigners[DETECTION_TASK] = (
cn_assigner.CenterNetBoxTargetAssigner(stride))
......@@ -2275,11 +2305,10 @@ class CenterNetMetaArch(model.DetectionModel):
Returns:
A float scalar tensor representing the object center loss per instance.
"""
gt_boxes_list = self.groundtruth_lists(fields.BoxListFields.boxes)
gt_classes_list = self.groundtruth_lists(fields.BoxListFields.classes)
gt_weights_list = self.groundtruth_lists(fields.BoxListFields.weights)
if self._center_params.use_only_known_classes:
if self._center_params.use_labeled_classes:
gt_labeled_classes_list = self.groundtruth_lists(
fields.InputDataFields.groundtruth_labeled_classes)
batch_labeled_classes = tf.stack(gt_labeled_classes_list, axis=0)
......@@ -2291,12 +2320,22 @@ class CenterNetMetaArch(model.DetectionModel):
# Convert the groundtruth to targets.
assigner = self._target_assigner_dict[OBJECT_CENTER]
heatmap_targets = assigner.assign_center_targets_from_boxes(
height=input_height,
width=input_width,
gt_boxes_list=gt_boxes_list,
gt_classes_list=gt_classes_list,
gt_weights_list=gt_weights_list)
if self._center_from_keypoints:
gt_keypoints_list = self.groundtruth_lists(fields.BoxListFields.keypoints)
heatmap_targets = assigner.assign_center_targets_from_keypoints(
height=input_height,
width=input_width,
gt_classes_list=gt_classes_list,
gt_keypoints_list=gt_keypoints_list,
gt_weights_list=gt_weights_list)
else:
gt_boxes_list = self.groundtruth_lists(fields.BoxListFields.boxes)
heatmap_targets = assigner.assign_center_targets_from_boxes(
height=input_height,
width=input_width,
gt_boxes_list=gt_boxes_list,
gt_classes_list=gt_classes_list,
gt_weights_list=gt_weights_list)
flattened_heatmap_targets = _flatten_spatial_dimensions(heatmap_targets)
num_boxes = _to_float32(get_num_instances_from_weights(gt_weights_list))
......
......@@ -2108,7 +2108,7 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
# def graph_fn():
detections = model.postprocess(prediction_dict,
tf.constant([[128, 128, 3]]))
# return detections
# return detections
# detections = self.execute_cpu(graph_fn, [])
self.assertAllClose(detections['detection_scores'][0],
......@@ -2716,17 +2716,16 @@ class CenterNetMetaComputeLossTest(test_case.TestCase, parameterized.TestCase):
# The prediction and groundtruth are curated to produce very low loss.
self.assertGreater(0.01, loss)
default_value = self.model._center_params.use_only_known_classes
default_value = self.model._center_params.use_labeled_classes
self.model._center_params = (
self.model._center_params._replace(use_only_known_classes=True))
self.model._center_params._replace(use_labeled_classes=True))
loss = self.model._compute_object_center_loss(
object_center_predictions=self.prediction_dict[cnma.OBJECT_CENTER],
input_height=self.input_height,
input_width=self.input_width,
per_pixel_weights=self.per_pixel_weights)
self.model._center_params = (
self.model._center_params._replace(
use_only_known_classes=default_value))
self.model._center_params._replace(use_labeled_classes=default_value))
# The prediction and groundtruth are curated to produce very low loss.
self.assertGreater(0.01, loss)
......
......@@ -73,6 +73,14 @@ message CenterNet {
// If set, loss is only computed for the labeled classes.
optional bool use_labeled_classes = 6 [default = false];
// The keypoint weights used for calculating the location of the object
// center. When this field is provided, the number of weights needs to be
// the same as the number of keypoints. The object center is calculated as
// the weighted mean of the keypoint locations. When the field is not
// provided, the object center is determined by the bounding box groundtruth
// annotations (default behavior).
repeated float keypoint_weights_for_center = 7;
}
optional ObjectCenterParams object_center_params = 5;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment