Commit 6dd70405 authored by Yu-hui Chen's avatar Yu-hui Chen Committed by TF Object Detection Team
Browse files

Rewrote the postprocessing logics of the multi-class keypoint task in the

CenterNet meta arch to avoid using scatter_nd and tf.cond.

Also, use static_or_dynamic_map_fn in the convert_strided_predictions_to_normalized_keypoints function to avoid introducing cycles in the graph that causes the TPU converter errors.

PiperOrigin-RevId: 394152521
parent 163ca152
...@@ -1796,15 +1796,9 @@ def convert_strided_predictions_to_normalized_keypoints( ...@@ -1796,15 +1796,9 @@ def convert_strided_predictions_to_normalized_keypoints(
keypoints, window = inputs keypoints, window = inputs
return keypoint_ops.clip_to_window(keypoints, window) return keypoint_ops.clip_to_window(keypoints, window)
# Specify the TensorSpec explicitly in the tf.map_fn to make it tf.lite keypoint_coords_normalized = shape_utils.static_or_dynamic_map_fn(
# compatible. clip_to_window, [keypoint_coords_normalized, batch_window],
kpts_dims = _get_shape(keypoint_coords_normalized, 4) dtype=tf.float32, back_prop=False)
output_spec = tf.TensorSpec(
shape=[kpts_dims[1], kpts_dims[2], kpts_dims[3]], dtype=tf.float32)
keypoint_coords_normalized = tf.map_fn(
clip_to_window, (keypoint_coords_normalized, batch_window),
dtype=tf.float32, back_prop=False,
fn_output_signature=output_spec)
keypoint_scores = tf.where(valid_indices, keypoint_scores, keypoint_scores = tf.where(valid_indices, keypoint_scores,
tf.zeros_like(keypoint_scores)) tf.zeros_like(keypoint_scores))
return keypoint_coords_normalized, keypoint_scores return keypoint_coords_normalized, keypoint_scores
...@@ -4385,9 +4379,13 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -4385,9 +4379,13 @@ class CenterNetMetaArch(model.DetectionModel):
kpt_coords_for_example_list = [] kpt_coords_for_example_list = []
kpt_scores_for_example_list = [] kpt_scores_for_example_list = []
for ex_ind in range(batch_size): for ex_ind in range(batch_size):
kpt_coords_for_class_list = [] # The tensors that host the keypoint coordinates and scores for all
kpt_scores_for_class_list = [] # instances and all keypoints. They will be updated by scatter_nd_add for
instance_inds_for_class_list = [] # each keypoint tasks.
kpt_coords_for_example_all_det = tf.zeros(
[max_detections, total_num_keypoints, 2])
kpt_scores_for_example_all_det = tf.zeros(
[max_detections, total_num_keypoints])
for task_name, kp_params in self._kp_params_dict.items(): for task_name, kp_params in self._kp_params_dict.items():
keypoint_heatmap = prediction_dict[ keypoint_heatmap = prediction_dict[
get_keypoint_name(task_name, KEYPOINT_HEATMAP)][-1] get_keypoint_name(task_name, KEYPOINT_HEATMAP)][-1]
...@@ -4397,77 +4395,62 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -4397,77 +4395,62 @@ class CenterNetMetaArch(model.DetectionModel):
get_keypoint_name(task_name, KEYPOINT_REGRESSION)][-1] get_keypoint_name(task_name, KEYPOINT_REGRESSION)][-1]
instance_inds = self._get_instance_indices( instance_inds = self._get_instance_indices(
classes, num_detections, ex_ind, kp_params.class_id) classes, num_detections, ex_ind, kp_params.class_id)
num_ind = _get_shape(instance_inds, 1)
# Gather the feature map locations corresponding to the object class.
def true_fn(keypoint_heatmap, keypoint_offsets, keypoint_regression, y_indices_for_kpt_class = tf.gather(y_indices, instance_inds, axis=1)
classes, y_indices, x_indices, boxes, instance_inds, ex_ind, x_indices_for_kpt_class = tf.gather(x_indices, instance_inds, axis=1)
kp_params): if boxes is None:
"""Logics to execute when instance_inds is not an empty set.""" boxes_for_kpt_class = None
# Gather the feature map locations corresponding to the object class. else:
y_indices_for_kpt_class = tf.gather(y_indices, instance_inds, axis=1) boxes_for_kpt_class = tf.gather(boxes, instance_inds, axis=1)
x_indices_for_kpt_class = tf.gather(x_indices, instance_inds, axis=1)
if boxes is None: # Postprocess keypoints and scores for class and single image. Shapes
boxes_for_kpt_class = None # are [1, num_instances_i, num_keypoints_i, 2] and
else: # [1, num_instances_i, num_keypoints_i], respectively. Note that
boxes_for_kpt_class = tf.gather(boxes, instance_inds, axis=1) # num_instances_i and num_keypoints_i refers to the number of
# instances and keypoints for class i, respectively.
# Postprocess keypoints and scores for class and single image. Shapes (kpt_coords_for_class, kpt_scores_for_class, _) = (
# are [1, num_instances_i, num_keypoints_i, 2] and self._postprocess_keypoints_for_class_and_image(
# [1, num_instances_i, num_keypoints_i], respectively. Note that keypoint_heatmap,
# num_instances_i and num_keypoints_i refers to the number of keypoint_offsets,
# instances and keypoints for class i, respectively. keypoint_regression,
(kpt_coords_for_class, kpt_scores_for_class, _) = ( classes,
self._postprocess_keypoints_for_class_and_image( y_indices_for_kpt_class,
keypoint_heatmap, x_indices_for_kpt_class,
keypoint_offsets, boxes_for_kpt_class,
keypoint_regression, ex_ind,
classes, kp_params,
y_indices_for_kpt_class, ))
x_indices_for_kpt_class,
boxes_for_kpt_class, # Prepare the indices for scatter_nd. The resulting combined_inds has
ex_ind, # the shape of [num_instances_i * num_keypoints_i, 2], where the first
kp_params, # column corresponds to the instance IDs and the second column
)) # corresponds to the keypoint IDs.
kpt_inds = tf.constant(kp_params.keypoint_indices, dtype=tf.int32)
# Expand keypoint dimension (with padding) so that coordinates and kpt_inds = tf.expand_dims(kpt_inds, axis=0)
# scores have shape [1, num_instances_i, num_total_keypoints, 2] and instance_inds_expand = tf.expand_dims(instance_inds, axis=-1)
# [1, num_instances_i, num_total_keypoints], respectively. kpt_inds_expand = kpt_inds * tf.ones_like(instance_inds_expand)
kpts_coords_for_class_padded, kpt_scores_for_class_padded = ( instance_inds_expand = instance_inds_expand * tf.ones_like(kpt_inds)
_pad_to_full_keypoint_dim(kpt_coords_for_class, combined_inds = tf.stack(
kpt_scores_for_class, [instance_inds_expand, kpt_inds_expand], axis=2)
kp_params.keypoint_indices, combined_inds = tf.reshape(combined_inds, [-1, 2])
total_num_keypoints))
return kpts_coords_for_class_padded, kpt_scores_for_class_padded # Reshape the keypoint coordinates/scores to [num_instances_i *
# num_keypoints_i, 2]/[num_instances_i * num_keypoints_i] to be used
def false_fn(): # by scatter_nd_add.
"""Logics to execute when the instance_inds is an empty set.""" kpt_coords_for_class = tf.reshape(kpt_coords_for_class, [-1, 2])
return (tf.zeros([1, 0, total_num_keypoints, 2], dtype=tf.float32), kpt_scores_for_class = tf.reshape(kpt_scores_for_class, [-1])
tf.zeros([1, 0, total_num_keypoints], dtype=tf.float32)) kpt_coords_for_example_all_det = tf.tensor_scatter_nd_add(
kpt_coords_for_example_all_det,
true_fn = functools.partial( combined_inds, kpt_coords_for_class)
true_fn, keypoint_heatmap, keypoint_offsets, keypoint_regression, kpt_scores_for_example_all_det = tf.tensor_scatter_nd_add(
classes, y_indices, x_indices, boxes, instance_inds, ex_ind, kpt_scores_for_example_all_det,
kp_params) combined_inds, kpt_scores_for_class)
# Use dimension values instead of tf.size for tf.lite compatibility.
results = tf.cond(num_ind[0] > 0, true_fn, false_fn) kpt_coords_for_example_list.append(
tf.expand_dims(kpt_coords_for_example_all_det, axis=0))
kpt_coords_for_class_list.append(results[0]) kpt_scores_for_example_list.append(
kpt_scores_for_class_list.append(results[1]) tf.expand_dims(kpt_scores_for_example_all_det, axis=0))
instance_inds_for_class_list.append(instance_inds)
# Concatenate all keypoints across all classes (single example).
kpt_coords_for_example = tf.concat(kpt_coords_for_class_list, axis=1)
kpt_scores_for_example = tf.concat(kpt_scores_for_class_list, axis=1)
instance_inds_for_example = tf.concat(instance_inds_for_class_list,
axis=0)
(kpt_coords_for_example_all_det,
kpt_scores_for_example_all_det) = self._scatter_keypoints_to_batch(
num_ind, kpt_coords_for_example, kpt_scores_for_example,
instance_inds_for_example, max_detections, total_num_keypoints)
kpt_coords_for_example_list.append(kpt_coords_for_example_all_det)
kpt_scores_for_example_list.append(kpt_scores_for_example_all_det)
# Concatenate all keypoints and scores from all examples in the batch. # Concatenate all keypoints and scores from all examples in the batch.
# Shapes are [batch_size, max_detections, num_total_keypoints, 2] and # Shapes are [batch_size, max_detections, num_total_keypoints, 2] and
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment