Commit 191d9624 authored by Yu-hui Chen's avatar Yu-hui Chen Committed by TF Object Detection Team
Browse files

Updated the postprocessing ops to avoid reshaping tensors with dimension > 4.

This allows the exported TF Lite MoveNet model to be able to run on device GPU
using OpenCL/OpenGL.

PiperOrigin-RevId: 381343791
parent 1fed4144
......@@ -598,7 +598,7 @@ def prediction_tensors_to_single_instance_kpts(
keypoint type, as it's possible to filter some candidates due to the score
threshold.
"""
batch_size, height, width, num_keypoints = _get_shape(
batch_size, _, _, num_keypoints = _get_shape(
keypoint_heatmap_predictions, 4)
# Get x, y and channel indices corresponding to the top indices in the
# keypoint heatmap predictions.
......@@ -612,24 +612,32 @@ def prediction_tensors_to_single_instance_kpts(
_multi_range(batch_size, value_repetitions=num_keypoints),
tf.reshape(y_indices, [-1]),
tf.reshape(x_indices, [-1]),
tf.reshape(channel_indices, [-1])
], axis=1)
# Reshape the offsets predictions to shape:
# [batch_size, height, width, num_keypoints, 2]
keypoint_heatmap_offsets = tf.reshape(
keypoint_heatmap_offsets, [batch_size, height, width, num_keypoints, -1])
# shape: [num_keypoints, 2]
# shape: [num_keypoints, num_keypoints * 2]
selected_offsets_flat = tf.gather_nd(keypoint_heatmap_offsets,
combined_indices)
y_offsets, x_offsets = tf.unstack(selected_offsets_flat, axis=1)
# shape: [num_keypoints, num_keypoints, 2].
selected_offsets_flat = tf.reshape(
selected_offsets_flat, [num_keypoints, num_keypoints, -1])
# shape: [num_keypoints].
channel_indices = tf.keras.backend.flatten(channel_indices)
# shape: [num_keypoints, 2].
retrieve_indices = tf.stack([channel_indices, channel_indices], axis=1)
# shape: [num_keypoints, 2]
selected_offsets = tf.gather_nd(selected_offsets_flat, retrieve_indices)
y_offsets, x_offsets = tf.unstack(selected_offsets, axis=1)
keypoint_candidates = tf.stack([
tf.cast(y_indices, dtype=tf.float32) + tf.expand_dims(y_offsets, axis=0),
tf.cast(x_indices, dtype=tf.float32) + tf.expand_dims(x_offsets, axis=0)
], axis=2)
keypoint_candidates = tf.expand_dims(keypoint_candidates, axis=0)
# Append the channel indices back to retrieve the keypoint scores from the
# heatmap.
combined_indices = tf.concat(
[combined_indices, tf.expand_dims(channel_indices, axis=-1)], axis=1)
if keypoint_score_heatmap is None:
keypoint_scores = tf.gather_nd(
keypoint_heatmap_predictions, combined_indices)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment