Commit 191d9624 authored by Yu-hui Chen's avatar Yu-hui Chen Committed by TF Object Detection Team
Browse files

Updated the postprocessing ops to avoid reshaping tensors with dimension > 4.

This allows the exported TF Lite MoveNet model to be able to run on device GPU
using OpenCL/OpenGL.

PiperOrigin-RevId: 381343791
parent 1fed4144
...@@ -598,7 +598,7 @@ def prediction_tensors_to_single_instance_kpts( ...@@ -598,7 +598,7 @@ def prediction_tensors_to_single_instance_kpts(
keypoint type, as it's possible to filter some candidates due to the score keypoint type, as it's possible to filter some candidates due to the score
threshold. threshold.
""" """
batch_size, height, width, num_keypoints = _get_shape( batch_size, _, _, num_keypoints = _get_shape(
keypoint_heatmap_predictions, 4) keypoint_heatmap_predictions, 4)
# Get x, y and channel indices corresponding to the top indices in the # Get x, y and channel indices corresponding to the top indices in the
# keypoint heatmap predictions. # keypoint heatmap predictions.
...@@ -612,24 +612,32 @@ def prediction_tensors_to_single_instance_kpts( ...@@ -612,24 +612,32 @@ def prediction_tensors_to_single_instance_kpts(
_multi_range(batch_size, value_repetitions=num_keypoints), _multi_range(batch_size, value_repetitions=num_keypoints),
tf.reshape(y_indices, [-1]), tf.reshape(y_indices, [-1]),
tf.reshape(x_indices, [-1]), tf.reshape(x_indices, [-1]),
tf.reshape(channel_indices, [-1])
], axis=1) ], axis=1)
# Reshape the offsets predictions to shape: # shape: [num_keypoints, num_keypoints * 2]
# [batch_size, height, width, num_keypoints, 2]
keypoint_heatmap_offsets = tf.reshape(
keypoint_heatmap_offsets, [batch_size, height, width, num_keypoints, -1])
# shape: [num_keypoints, 2]
selected_offsets_flat = tf.gather_nd(keypoint_heatmap_offsets, selected_offsets_flat = tf.gather_nd(keypoint_heatmap_offsets,
combined_indices) combined_indices)
y_offsets, x_offsets = tf.unstack(selected_offsets_flat, axis=1) # shape: [num_keypoints, num_keypoints, 2].
selected_offsets_flat = tf.reshape(
selected_offsets_flat, [num_keypoints, num_keypoints, -1])
# shape: [num_keypoints].
channel_indices = tf.keras.backend.flatten(channel_indices)
# shape: [num_keypoints, 2].
retrieve_indices = tf.stack([channel_indices, channel_indices], axis=1)
# shape: [num_keypoints, 2]
selected_offsets = tf.gather_nd(selected_offsets_flat, retrieve_indices)
y_offsets, x_offsets = tf.unstack(selected_offsets, axis=1)
keypoint_candidates = tf.stack([ keypoint_candidates = tf.stack([
tf.cast(y_indices, dtype=tf.float32) + tf.expand_dims(y_offsets, axis=0), tf.cast(y_indices, dtype=tf.float32) + tf.expand_dims(y_offsets, axis=0),
tf.cast(x_indices, dtype=tf.float32) + tf.expand_dims(x_offsets, axis=0) tf.cast(x_indices, dtype=tf.float32) + tf.expand_dims(x_offsets, axis=0)
], axis=2) ], axis=2)
keypoint_candidates = tf.expand_dims(keypoint_candidates, axis=0) keypoint_candidates = tf.expand_dims(keypoint_candidates, axis=0)
# Append the channel indices back to retrieve the keypoint scores from the
# heatmap.
combined_indices = tf.concat(
[combined_indices, tf.expand_dims(channel_indices, axis=-1)], axis=1)
if keypoint_score_heatmap is None: if keypoint_score_heatmap is None:
keypoint_scores = tf.gather_nd( keypoint_scores = tf.gather_nd(
keypoint_heatmap_predictions, combined_indices) keypoint_heatmap_predictions, combined_indices)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment