Commit 071b3b94 authored by Yu-hui Chen's avatar Yu-hui Chen Committed by TF Object Detection Team
Browse files

Updated logics and ops in CenterNet postprocessing functions to make the model

(object detection/keypoint prediction tasks) tf.lite compatible.

PiperOrigin-RevId: 333766281
parent bc7c670f
...@@ -588,15 +588,23 @@ def refine_keypoints(regressed_keypoints, keypoint_candidates, keypoint_scores, ...@@ -588,15 +588,23 @@ def refine_keypoints(regressed_keypoints, keypoint_candidates, keypoint_scores,
# Pairwise squared distances between regressed keypoints and candidate # Pairwise squared distances between regressed keypoints and candidate
# keypoints (for a single keypoint type). # keypoints (for a single keypoint type).
# Shape [batch_size, num_instances, max_candidates, num_keypoints]. # Shape [batch_size, num_instances, 1, num_keypoints, 2].
regressed_keypoint_expanded = tf.expand_dims(regressed_keypoints, regressed_keypoint_expanded = tf.expand_dims(regressed_keypoints,
axis=2) axis=2)
# Shape [batch_size, 1, max_candidates, num_keypoints, 2].
keypoint_candidates_expanded = tf.expand_dims( keypoint_candidates_expanded = tf.expand_dims(
keypoint_candidates_with_nans, axis=1) keypoint_candidates_with_nans, axis=1)
sqrd_distances = tf.math.reduce_sum( # Use explicit tensor shape broadcasting (since the tensor dimensions are
tf.math.squared_difference(regressed_keypoint_expanded, # expanded to 5D) to make it tf.lite compatible.
keypoint_candidates_expanded), regressed_keypoint_expanded = tf.tile(
axis=-1) regressed_keypoint_expanded, multiples=[1, 1, max_candidates, 1, 1])
keypoint_candidates_expanded = tf.tile(
keypoint_candidates_expanded, multiples=[1, num_instances, 1, 1, 1])
# Replace tf.math.squared_difference by "-" operator and tf.multiply ops since
# tf.lite convert doesn't support squared_difference with undetermined
# dimension.
diff = regressed_keypoint_expanded - keypoint_candidates_expanded
sqrd_distances = tf.math.reduce_sum(tf.multiply(diff, diff), axis=-1)
distances = tf.math.sqrt(sqrd_distances) distances = tf.math.sqrt(sqrd_distances)
# Determine the candidates that have the minimum distance to the regressed # Determine the candidates that have the minimum distance to the regressed
...@@ -968,9 +976,16 @@ def convert_strided_predictions_to_normalized_keypoints( ...@@ -968,9 +976,16 @@ def convert_strided_predictions_to_normalized_keypoints(
def clip_to_window(inputs): def clip_to_window(inputs):
keypoints, window = inputs keypoints, window = inputs
return keypoint_ops.clip_to_window(keypoints, window) return keypoint_ops.clip_to_window(keypoints, window)
# Specify the TensorSpec explicitly in the tf.map_fn to make it tf.lite
# compatible.
kpts_dims = _get_shape(keypoint_coords_normalized, 4)
output_spec = tf.TensorSpec(
shape=[kpts_dims[1], kpts_dims[2], kpts_dims[3]], dtype=tf.float32)
keypoint_coords_normalized = tf.map_fn( keypoint_coords_normalized = tf.map_fn(
clip_to_window, (keypoint_coords_normalized, batch_window), clip_to_window, (keypoint_coords_normalized, batch_window),
dtype=tf.float32, back_prop=False) dtype=tf.float32, back_prop=False,
fn_output_signature=output_spec)
keypoint_scores = tf.where(valid_indices, keypoint_scores, keypoint_scores = tf.where(valid_indices, keypoint_scores,
tf.zeros_like(keypoint_scores)) tf.zeros_like(keypoint_scores))
return keypoint_coords_normalized, keypoint_scores return keypoint_coords_normalized, keypoint_scores
...@@ -2891,6 +2906,7 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -2891,6 +2906,7 @@ class CenterNetMetaArch(model.DetectionModel):
get_keypoint_name(task_name, KEYPOINT_REGRESSION)][-1] get_keypoint_name(task_name, KEYPOINT_REGRESSION)][-1]
instance_inds = self._get_instance_indices( instance_inds = self._get_instance_indices(
classes, num_detections, ex_ind, kp_params.class_id) classes, num_detections, ex_ind, kp_params.class_id)
num_ind = _get_shape(instance_inds, 1)
def true_fn( def true_fn(
keypoint_heatmap, keypoint_offsets, keypoint_regression, keypoint_heatmap, keypoint_offsets, keypoint_regression,
...@@ -2925,7 +2941,8 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -2925,7 +2941,8 @@ class CenterNetMetaArch(model.DetectionModel):
true_fn, keypoint_heatmap, keypoint_offsets, keypoint_regression, true_fn, keypoint_heatmap, keypoint_offsets, keypoint_regression,
classes, y_indices, x_indices, boxes, instance_inds, ex_ind, classes, y_indices, x_indices, boxes, instance_inds, ex_ind,
kp_params) kp_params)
results = tf.cond(tf.size(instance_inds) > 0, true_fn, false_fn) # Use dimension values instead of tf.size for tf.lite compatibility.
results = tf.cond(num_ind[0] > 0, true_fn, false_fn)
kpt_coords_for_class_list.append(results[0]) kpt_coords_for_class_list.append(results[0])
kpt_scores_for_class_list.append(results[1]) kpt_scores_for_class_list.append(results[1])
...@@ -2937,7 +2954,9 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -2937,7 +2954,9 @@ class CenterNetMetaArch(model.DetectionModel):
instance_inds_for_example = tf.concat(instance_inds_for_class_list, instance_inds_for_example = tf.concat(instance_inds_for_class_list,
axis=0) axis=0)
if tf.size(instance_inds_for_example) > 0: # Use dimension values instead of tf.size for tf.lite compatibility.
num_inds = _get_shape(instance_inds_for_example, 1)
if num_inds[0] > 0:
# Scatter into tensor where instances align with original detection # Scatter into tensor where instances align with original detection
# instances. New shape of keypoint coordinates and scores are # instances. New shape of keypoint coordinates and scores are
# [1, max_detections, num_total_keypoints, 2] and # [1, max_detections, num_total_keypoints, 2] and
...@@ -2977,7 +2996,7 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -2977,7 +2996,7 @@ class CenterNetMetaArch(model.DetectionModel):
class_id: Class id class_id: Class id
Returns: Returns:
instance_inds: A [num_instances] int tensor where each element indicates instance_inds: A [num_instances] int32 tensor where each element indicates
the instance location within the `classes` tensor. This is useful to the instance location within the `classes` tensor. This is useful to
associate the refined keypoints with the original detections (i.e. associate the refined keypoints with the original detections (i.e.
boxes) boxes)
...@@ -2986,11 +3005,14 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -2986,11 +3005,14 @@ class CenterNetMetaArch(model.DetectionModel):
_, max_detections = shape_utils.combined_static_and_dynamic_shape( _, max_detections = shape_utils.combined_static_and_dynamic_shape(
classes) classes)
# Get the detection indices corresponding to the target class. # Get the detection indices corresponding to the target class.
# Call tf.math.equal with matched tensor shape to make it tf.lite
# compatible.
valid_detections_with_kpt_class = tf.math.logical_and( valid_detections_with_kpt_class = tf.math.logical_and(
tf.range(max_detections) < num_detections[batch_index], tf.range(max_detections) < num_detections[batch_index],
classes[0] == class_id) tf.math.equal(classes[0], tf.fill(classes[0].shape, class_id)))
instance_inds = tf.where(valid_detections_with_kpt_class)[:, 0] instance_inds = tf.where(valid_detections_with_kpt_class)[:, 0]
return instance_inds # Cast the indices tensor to int32 for tf.lite compatibility.
return tf.cast(instance_inds, tf.int32)
def _postprocess_keypoints_for_class_and_image( def _postprocess_keypoints_for_class_and_image(
self, keypoint_heatmap, keypoint_offsets, keypoint_regression, classes, self, keypoint_heatmap, keypoint_offsets, keypoint_regression, classes,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment