Commit 6dd70405 authored by Yu-hui Chen, committed by TF Object Detection Team

Rewrote the postprocessing logic of the multi-class keypoint task in the CenterNet meta arch to avoid using scatter_nd and tf.cond.

Also, use static_or_dynamic_map_fn in the convert_strided_predictions_to_normalized_keypoints function to avoid introducing cycles in the graph that cause TPU converter errors.

PiperOrigin-RevId: 394152521
parent 163ca152
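
For context on the map_fn change: tf.map_fn is implemented with tf.while_loop, whose control-flow edges are the likely source of the graph cycles mentioned above. When the leading dimension is statically known, the same per-example computation can be expressed by unstacking the batch and re-stacking the results. The snippet below is only a rough sketch of that static-unroll idea, not the actual shape_utils.static_or_dynamic_map_fn implementation; the helper name and toy inputs are made up for illustration.

import tensorflow as tf


def static_map_fn(fn, elems):
  """Applies fn per batch element by unstacking instead of using tf.map_fn.

  Illustrative sketch only: assumes every tensor in elems has a static
  leading (batch) dimension. The real static_or_dynamic_map_fn also falls
  back to tf.map_fn when the batch dimension is dynamic.
  """
  unstacked = [tf.unstack(elem, axis=0) for elem in elems]
  per_example_outputs = [fn(list(inputs)) for inputs in zip(*unstacked)]
  return tf.stack(per_example_outputs, axis=0)


# Toy usage: shift each example's keypoints by its own (y, x) offset.
keypoints = tf.random.uniform([2, 3, 17, 2])      # [batch, instances, keypoints, 2]
offsets = tf.constant([[0.1, 0.0], [0.0, 0.2]])   # [batch, 2]
shifted = static_map_fn(lambda inputs: inputs[0] + inputs[1],
                        [keypoints, offsets])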
@@ -1796,15 +1796,9 @@ def convert_strided_predictions_to_normalized_keypoints(
keypoints, window = inputs
return keypoint_ops.clip_to_window(keypoints, window)
# Specify the TensorSpec explicitly in the tf.map_fn to make it tf.lite
# compatible.
kpts_dims = _get_shape(keypoint_coords_normalized, 4)
output_spec = tf.TensorSpec(
shape=[kpts_dims[1], kpts_dims[2], kpts_dims[3]], dtype=tf.float32)
keypoint_coords_normalized = tf.map_fn(
clip_to_window, (keypoint_coords_normalized, batch_window),
dtype=tf.float32, back_prop=False,
fn_output_signature=output_spec)
keypoint_coords_normalized = shape_utils.static_or_dynamic_map_fn(
clip_to_window, [keypoint_coords_normalized, batch_window],
dtype=tf.float32, back_prop=False)
keypoint_scores = tf.where(valid_indices, keypoint_scores,
tf.zeros_like(keypoint_scores))
return keypoint_coords_normalized, keypoint_scores
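
The clip_to_window callback above delegates to keypoint_ops.clip_to_window. As a standalone illustration of that clipping behavior (a simplified sketch, not the library code, assuming the [ymin, xmin, ymax, xmax] window convention):

import tensorflow as tf


def clip_keypoints_to_window(keypoints, window):
  """Clips (y, x) keypoints to a window; simplified sketch.

  Args:
    keypoints: [num_instances, num_keypoints, 2] tensor of (y, x) coordinates.
    window: [4] tensor assumed to be [ymin, xmin, ymax, xmax].

  Returns:
    Keypoints with each coordinate clamped to the window boundaries.
  """
  ymin, xmin, ymax, xmax = tf.unstack(window)
  y = tf.clip_by_value(keypoints[..., 0], ymin, ymax)
  x = tf.clip_by_value(keypoints[..., 1], xmin, xmax)
  return tf.stack([y, x], axis=-1)


# Keypoints outside the unit window get clamped to its edges:
kpts = tf.constant([[[-0.2, 0.5], [0.3, 1.4]]])   # [1, 2, 2]
clipped = clip_keypoints_to_window(kpts, tf.constant([0., 0., 1., 1.]))
# clipped -> [[[0.0, 0.5], [0.3, 1.0]]]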
@@ -4385,9 +4379,13 @@ class CenterNetMetaArch(model.DetectionModel):
kpt_coords_for_example_list = []
kpt_scores_for_example_list = []
for ex_ind in range(batch_size):
kpt_coords_for_class_list = []
kpt_scores_for_class_list = []
instance_inds_for_class_list = []
# The tensors that host the keypoint coordinates and scores for all
# instances and all keypoints. They will be updated by scatter_nd_add for
# each keypoint task.
kpt_coords_for_example_all_det = tf.zeros(
[max_detections, total_num_keypoints, 2])
kpt_scores_for_example_all_det = tf.zeros(
[max_detections, total_num_keypoints])
for task_name, kp_params in self._kp_params_dict.items():
keypoint_heatmap = prediction_dict[
get_keypoint_name(task_name, KEYPOINT_HEATMAP)][-1]
@@ -4397,77 +4395,62 @@ class CenterNetMetaArch(model.DetectionModel):
get_keypoint_name(task_name, KEYPOINT_REGRESSION)][-1]
instance_inds = self._get_instance_indices(
classes, num_detections, ex_ind, kp_params.class_id)
num_ind = _get_shape(instance_inds, 1)
def true_fn(keypoint_heatmap, keypoint_offsets, keypoint_regression,
classes, y_indices, x_indices, boxes, instance_inds, ex_ind,
kp_params):
"""Logics to execute when instance_inds is not an empty set."""
# Gather the feature map locations corresponding to the object class.
y_indices_for_kpt_class = tf.gather(y_indices, instance_inds, axis=1)
x_indices_for_kpt_class = tf.gather(x_indices, instance_inds, axis=1)
if boxes is None:
boxes_for_kpt_class = None
else:
boxes_for_kpt_class = tf.gather(boxes, instance_inds, axis=1)
# Postprocess keypoints and scores for class and single image. Shapes
# are [1, num_instances_i, num_keypoints_i, 2] and
# [1, num_instances_i, num_keypoints_i], respectively. Note that
# num_instances_i and num_keypoints_i refer to the number of
# instances and keypoints for class i, respectively.
(kpt_coords_for_class, kpt_scores_for_class, _) = (
self._postprocess_keypoints_for_class_and_image(
keypoint_heatmap,
keypoint_offsets,
keypoint_regression,
classes,
y_indices_for_kpt_class,
x_indices_for_kpt_class,
boxes_for_kpt_class,
ex_ind,
kp_params,
))
# Expand keypoint dimension (with padding) so that coordinates and
# scores have shape [1, num_instances_i, num_total_keypoints, 2] and
# [1, num_instances_i, num_total_keypoints], respectively.
kpts_coords_for_class_padded, kpt_scores_for_class_padded = (
_pad_to_full_keypoint_dim(kpt_coords_for_class,
kpt_scores_for_class,
kp_params.keypoint_indices,
total_num_keypoints))
return kpts_coords_for_class_padded, kpt_scores_for_class_padded
def false_fn():
"""Logics to execute when the instance_inds is an empty set."""
return (tf.zeros([1, 0, total_num_keypoints, 2], dtype=tf.float32),
tf.zeros([1, 0, total_num_keypoints], dtype=tf.float32))
true_fn = functools.partial(
true_fn, keypoint_heatmap, keypoint_offsets, keypoint_regression,
classes, y_indices, x_indices, boxes, instance_inds, ex_ind,
kp_params)
# Use dimension values instead of tf.size for tf.lite compatibility.
results = tf.cond(num_ind[0] > 0, true_fn, false_fn)
kpt_coords_for_class_list.append(results[0])
kpt_scores_for_class_list.append(results[1])
instance_inds_for_class_list.append(instance_inds)
# Concatenate all keypoints across all classes (single example).
kpt_coords_for_example = tf.concat(kpt_coords_for_class_list, axis=1)
kpt_scores_for_example = tf.concat(kpt_scores_for_class_list, axis=1)
instance_inds_for_example = tf.concat(instance_inds_for_class_list,
axis=0)
(kpt_coords_for_example_all_det,
kpt_scores_for_example_all_det) = self._scatter_keypoints_to_batch(
num_ind, kpt_coords_for_example, kpt_scores_for_example,
instance_inds_for_example, max_detections, total_num_keypoints)
kpt_coords_for_example_list.append(kpt_coords_for_example_all_det)
kpt_scores_for_example_list.append(kpt_scores_for_example_all_det)
# Gather the feature map locations corresponding to the object class.
y_indices_for_kpt_class = tf.gather(y_indices, instance_inds, axis=1)
x_indices_for_kpt_class = tf.gather(x_indices, instance_inds, axis=1)
if boxes is None:
boxes_for_kpt_class = None
else:
boxes_for_kpt_class = tf.gather(boxes, instance_inds, axis=1)
# Postprocess keypoints and scores for class and single image. Shapes
# are [1, num_instances_i, num_keypoints_i, 2] and
# [1, num_instances_i, num_keypoints_i], respectively. Note that
# num_instances_i and num_keypoints_i refer to the number of
# instances and keypoints for class i, respectively.
(kpt_coords_for_class, kpt_scores_for_class, _) = (
self._postprocess_keypoints_for_class_and_image(
keypoint_heatmap,
keypoint_offsets,
keypoint_regression,
classes,
y_indices_for_kpt_class,
x_indices_for_kpt_class,
boxes_for_kpt_class,
ex_ind,
kp_params,
))
# Prepare the indices for scatter_nd. The resulting combined_inds has
# the shape of [num_instances_i * num_keypoints_i, 2], where the first
# column corresponds to the instance IDs and the second column
# corresponds to the keypoint IDs.
kpt_inds = tf.constant(kp_params.keypoint_indices, dtype=tf.int32)
kpt_inds = tf.expand_dims(kpt_inds, axis=0)
instance_inds_expand = tf.expand_dims(instance_inds, axis=-1)
kpt_inds_expand = kpt_inds * tf.ones_like(instance_inds_expand)
instance_inds_expand = instance_inds_expand * tf.ones_like(kpt_inds)
combined_inds = tf.stack(
[instance_inds_expand, kpt_inds_expand], axis=2)
combined_inds = tf.reshape(combined_inds, [-1, 2])
# Reshape the keypoint coordinates/scores to [num_instances_i *
# num_keypoints_i, 2]/[num_instances_i * num_keypoints_i] to be used
# by scatter_nd_add.
kpt_coords_for_class = tf.reshape(kpt_coords_for_class, [-1, 2])
kpt_scores_for_class = tf.reshape(kpt_scores_for_class, [-1])
kpt_coords_for_example_all_det = tf.tensor_scatter_nd_add(
kpt_coords_for_example_all_det,
combined_inds, kpt_coords_for_class)
kpt_scores_for_example_all_det = tf.tensor_scatter_nd_add(
kpt_scores_for_example_all_det,
combined_inds, kpt_scores_for_class)
kpt_coords_for_example_list.append(
tf.expand_dims(kpt_coords_for_example_all_det, axis=0))
kpt_scores_for_example_list.append(
tf.expand_dims(kpt_scores_for_example_all_det, axis=0))
# Concatenate all keypoints and scores from all examples in the batch.
# Shapes are [batch_size, max_detections, num_total_keypoints, 2] and
......
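
To make the scatter-based bookkeeping in the new code concrete, here is a small self-contained sketch of how per-class keypoint scores can be scattered into the preallocated [max_detections, total_num_keypoints] tensor without a tf.cond. The instance indices, keypoint indices, and scores below are made-up values for illustration only.

import tensorflow as tf

max_detections, total_num_keypoints = 5, 4

# Hypothetical task: it owns keypoints 1 and 3, and detections 0 and 2 belong
# to its class.
keypoint_indices = tf.constant([1, 3], dtype=tf.int32)    # [num_keypoints_i]
instance_inds = tf.constant([0, 2], dtype=tf.int32)       # [num_instances_i]
kpt_scores_for_class = tf.constant([[0.9, 0.8],
                                    [0.7, 0.6]])          # [num_instances_i, num_keypoints_i]

# Build [num_instances_i * num_keypoints_i, 2] (instance_id, keypoint_id)
# pairs, mirroring the combined_inds construction in the diff.
kpt_inds = tf.expand_dims(keypoint_indices, axis=0)       # [1, num_keypoints_i]
inst_inds = tf.expand_dims(instance_inds, axis=-1)        # [num_instances_i, 1]
kpt_inds_expand = kpt_inds * tf.ones_like(inst_inds)      # broadcast to [2, 2]
inst_inds_expand = inst_inds * tf.ones_like(kpt_inds)     # broadcast to [2, 2]
combined_inds = tf.reshape(
    tf.stack([inst_inds_expand, kpt_inds_expand], axis=2), [-1, 2])
# combined_inds == [[0, 1], [0, 3], [2, 1], [2, 3]]

# The destination starts at zero and no (instance, keypoint) pair repeats, so
# the add acts as an assignment. If instance_inds were empty, the scatter
# would simply be a no-op, which is why no tf.cond is needed.
scores_all_det = tf.zeros([max_detections, total_num_keypoints])
scores_all_det = tf.tensor_scatter_nd_add(
    scores_all_det,
    combined_inds,
    tf.reshape(kpt_scores_for_class, [-1]))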