Commit 071b3b94 authored by Yu-hui Chen, committed by TF Object Detection Team

Updated logic and ops in the CenterNet postprocessing functions to make the model
(object detection/keypoint prediction tasks) tf.lite compatible.

PiperOrigin-RevId: 333766281
parent bc7c670f
@@ -588,15 +588,23 @@ def refine_keypoints(regressed_keypoints, keypoint_candidates, keypoint_scores,
   # Pairwise squared distances between regressed keypoints and candidate
   # keypoints (for a single keypoint type).
-  # Shape [batch_size, num_instances, max_candidates, num_keypoints].
+  # Shape [batch_size, num_instances, 1, num_keypoints, 2].
   regressed_keypoint_expanded = tf.expand_dims(regressed_keypoints,
                                                axis=2)
+  # Shape [batch_size, 1, max_candidates, num_keypoints, 2].
   keypoint_candidates_expanded = tf.expand_dims(
       keypoint_candidates_with_nans, axis=1)
-  sqrd_distances = tf.math.reduce_sum(
-      tf.math.squared_difference(regressed_keypoint_expanded,
-                                 keypoint_candidates_expanded),
-      axis=-1)
+  # Use explicit tensor shape broadcasting (since the tensor dimensions are
+  # expanded to 5D) to make it tf.lite compatible.
+  regressed_keypoint_expanded = tf.tile(
+      regressed_keypoint_expanded, multiples=[1, 1, max_candidates, 1, 1])
+  keypoint_candidates_expanded = tf.tile(
+      keypoint_candidates_expanded, multiples=[1, num_instances, 1, 1, 1])
+  # Replace tf.math.squared_difference with the "-" operator and tf.multiply,
+  # since the tf.lite converter doesn't support squared_difference with an
+  # undetermined dimension.
+  diff = regressed_keypoint_expanded - keypoint_candidates_expanded
+  sqrd_distances = tf.math.reduce_sum(tf.multiply(diff, diff), axis=-1)
   distances = tf.math.sqrt(sqrd_distances)
   # Determine the candidates that have the minimum distance to the regressed
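For reference, here is a minimal runnable sketch of the pattern this hunk introduces, with invented tensor sizes standing in for the model's real shapes: the implicit 5-D broadcast is made explicit with tf.tile, and squared_difference becomes a subtraction followed by tf.multiply.

```python
import tensorflow as tf

# Illustrative sizes (assumed, not from the commit).
batch_size, num_instances, max_candidates, num_keypoints = 1, 3, 10, 17
regressed_keypoints = tf.random.uniform(
    [batch_size, num_instances, num_keypoints, 2])
keypoint_candidates = tf.random.uniform(
    [batch_size, max_candidates, num_keypoints, 2])

# Expand both operands to rank 5, then tile to the common shape
# [batch_size, num_instances, max_candidates, num_keypoints, 2] so the
# converter never has to broadcast an undetermined dimension.
regressed_expanded = tf.tile(
    tf.expand_dims(regressed_keypoints, axis=2),
    multiples=[1, 1, max_candidates, 1, 1])
candidates_expanded = tf.tile(
    tf.expand_dims(keypoint_candidates, axis=1),
    multiples=[1, num_instances, 1, 1, 1])

# "-" plus tf.multiply stands in for tf.math.squared_difference, which the
# tf.lite converter rejects when a broadcast dimension is undetermined.
diff = regressed_expanded - candidates_expanded
sqrd_distances = tf.math.reduce_sum(tf.multiply(diff, diff), axis=-1)
distances = tf.math.sqrt(sqrd_distances)  # shape [1, 3, 10, 17]
```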
@@ -968,9 +976,16 @@ def convert_strided_predictions_to_normalized_keypoints(
   def clip_to_window(inputs):
     keypoints, window = inputs
     return keypoint_ops.clip_to_window(keypoints, window)
+
+  # Specify the TensorSpec explicitly in the tf.map_fn to make it tf.lite
+  # compatible.
+  kpts_dims = _get_shape(keypoint_coords_normalized, 4)
+  output_spec = tf.TensorSpec(
+      shape=[kpts_dims[1], kpts_dims[2], kpts_dims[3]], dtype=tf.float32)
   keypoint_coords_normalized = tf.map_fn(
       clip_to_window, (keypoint_coords_normalized, batch_window),
-      dtype=tf.float32, back_prop=False)
+      dtype=tf.float32, back_prop=False,
+      fn_output_signature=output_spec)
   keypoint_scores = tf.where(valid_indices, keypoint_scores,
                              tf.zeros_like(keypoint_scores))
   return keypoint_coords_normalized, keypoint_scores
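A self-contained sketch of the tf.map_fn change follows. The clipping function and shapes here are illustrative stand-ins (the real code delegates to keypoint_ops.clip_to_window and derives dims via _get_shape); only the fn_output_signature usage mirrors the diff.

```python
import tensorflow as tf

# Illustrative shapes: [batch, max_detections, num_keypoints, 2] coordinates
# and one [y_min, x_min, y_max, x_max] window per example.
coords = tf.random.uniform([2, 5, 17, 2])
windows = tf.constant([[0., 0., 1., 1.],
                       [0., 0., 1., 1.]])

def clip_fn(inputs):
  # Stand-in for keypoint_ops.clip_to_window: clamp y/x into the window.
  kpts, window = inputs
  y_min, x_min, y_max, x_max = tf.unstack(window)
  y = tf.clip_by_value(kpts[..., 0], y_min, y_max)
  x = tf.clip_by_value(kpts[..., 1], x_min, x_max)
  return tf.stack([y, x], axis=-1)

# Without an explicit signature, map_fn infers the per-element shape lazily;
# pinning it with a TensorSpec keeps the output shape static for TFLite.
output_spec = tf.TensorSpec(shape=coords.shape[1:], dtype=tf.float32)
clipped = tf.map_fn(clip_fn, (coords, windows),
                    fn_output_signature=output_spec)
```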
@@ -2891,6 +2906,7 @@ class CenterNetMetaArch(model.DetectionModel):
             get_keypoint_name(task_name, KEYPOINT_REGRESSION)][-1]
         instance_inds = self._get_instance_indices(
             classes, num_detections, ex_ind, kp_params.class_id)
+        num_ind = _get_shape(instance_inds, 1)

         def true_fn(
             keypoint_heatmap, keypoint_offsets, keypoint_regression,
@@ -2925,7 +2941,8 @@ class CenterNetMetaArch(model.DetectionModel):
             true_fn, keypoint_heatmap, keypoint_offsets, keypoint_regression,
             classes, y_indices, x_indices, boxes, instance_inds, ex_ind,
             kp_params)
-        results = tf.cond(tf.size(instance_inds) > 0, true_fn, false_fn)
+        # Use dimension values instead of tf.size for tf.lite compatibility.
+        results = tf.cond(num_ind[0] > 0, true_fn, false_fn)

         kpt_coords_for_class_list.append(results[0])
         kpt_scores_for_class_list.append(results[1])
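The tf.size-to-shape swap can be sketched as below. _get_shape itself is not shown in this diff, so the helper here is an assumed stand-in that returns static dimension values where known and falls back to tf.shape otherwise.

```python
import tensorflow as tf

def get_shape(tensor, num_dims):
  """Assumed stand-in for _get_shape: static dims where known, else dynamic."""
  static = tensor.shape.as_list()
  assert len(static) == num_dims
  dynamic = tf.shape(tensor)
  return [dim if dim is not None else dynamic[i]
          for i, dim in enumerate(static)]

instance_inds = tf.constant([1, 4, 7], dtype=tf.int32)
num_ind = get_shape(instance_inds, 1)

# Comparing a dimension value avoids the extra tf.size op; when the shape is
# static the predicate can even be resolved at trace time.
results = tf.cond(num_ind[0] > 0,
                  lambda: tf.constant(1),
                  lambda: tf.constant(0))
```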
@@ -2937,7 +2954,9 @@ class CenterNetMetaArch(model.DetectionModel):
       instance_inds_for_example = tf.concat(instance_inds_for_class_list,
                                             axis=0)

-      if tf.size(instance_inds_for_example) > 0:
+      # Use dimension values instead of tf.size for tf.lite compatibility.
+      num_inds = _get_shape(instance_inds_for_example, 1)
+      if num_inds[0] > 0:
         # Scatter into tensor where instances align with original detection
         # instances. New shape of keypoint coordinates and scores are
         # [1, max_detections, num_total_keypoints, 2] and
@@ -2977,7 +2996,7 @@ class CenterNetMetaArch(model.DetectionModel):
       class_id: Class id

     Returns:
-      instance_inds: A [num_instances] int tensor where each element indicates
+      instance_inds: A [num_instances] int32 tensor where each element indicates
         the instance location within the `classes` tensor. This is useful to
         associate the refined keypoints with the original detections (i.e.
         boxes)
@@ -2986,11 +3005,14 @@ class CenterNetMetaArch(model.DetectionModel):
     _, max_detections = shape_utils.combined_static_and_dynamic_shape(
         classes)
     # Get the detection indices corresponding to the target class.
+    # Call tf.math.equal with matched tensor shapes to make it tf.lite
+    # compatible.
     valid_detections_with_kpt_class = tf.math.logical_and(
         tf.range(max_detections) < num_detections[batch_index],
-        classes[0] == class_id)
+        tf.math.equal(classes[0], tf.fill(classes[0].shape, class_id)))
     instance_inds = tf.where(valid_detections_with_kpt_class)[:, 0]
-    return instance_inds
+    # Cast the indices tensor to int32 for tf.lite compatibility.
+    return tf.cast(instance_inds, tf.int32)

   def _postprocess_keypoints_for_class_and_image(
       self, keypoint_heatmap, keypoint_offsets, keypoint_regression, classes,
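To close, a toy, self-contained sketch of the rewritten index computation; the classes/num_detections values are invented, while the shape-matched tf.math.equal/tf.fill comparison and the int32 cast follow the hunk above.

```python
import tensorflow as tf

# Toy inputs: one image, six detection slots, four of them valid.
max_detections = 6
classes = tf.constant([[0, 2, 2, 1, 2, 0]])   # [1, max_detections]
num_detections = tf.constant([4])
batch_index, class_id = 0, 2

# tf.fill gives the comparison a right-hand side with the same shape as
# classes[0], instead of relying on scalar broadcasting.
valid_detections_with_kpt_class = tf.math.logical_and(
    tf.range(max_detections) < num_detections[batch_index],
    tf.math.equal(classes[0], tf.fill(classes[0].shape, class_id)))
instance_inds = tf.where(valid_detections_with_kpt_class)[:, 0]

# tf.where returns int64; downstream gathers stay TFLite friendly with int32.
instance_inds = tf.cast(instance_inds, tf.int32)   # -> [1, 2]
```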