Commit baacb20d authored by A. Unique TensorFlower's avatar A. Unique TensorFlower Committed by TF Object Detection Team
Browse files

TF Lite compatibility fixes for CenterNet

PiperOrigin-RevId: 335067840
parent 25b8f835
......@@ -194,6 +194,36 @@ def _flatten_spatial_dimensions(batch_images):
channels])
def _multi_range(limit,
value_repetitions=1,
range_repetitions=1,
dtype=tf.int32):
"""Creates a sequence with optional value duplication and range repetition.
As an example (see the Args section for more details),
_multi_range(limit=2, value_repetitions=3, range_repetitions=4) returns:
[0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1]
Args:
limit: A 0-D Tensor (scalar). Upper limit of sequence, exclusive.
value_repetitions: Integer. The number of times a value in the sequence is
repeated. With value_repetitions=3, the result is [0, 0, 0, 1, 1, 1, ..].
range_repetitions: Integer. The number of times the range is repeated. With
range_repetitions=3, the result is [0, 1, 2, .., 0, 1, 2, ..].
dtype: The type of the elements of the resulting tensor.
Returns:
A 1-D tensor of type `dtype` and size
[`limit` * `value_repetitions` * `range_repetitions`] that contains the
specified range with given repetitions.
"""
return tf.reshape(
tf.tile(
tf.expand_dims(tf.range(limit, dtype=dtype), axis=-1),
multiples=[range_repetitions, value_repetitions]), [-1])
def top_k_feature_map_locations(feature_map, max_pool_kernel_size=3, k=100,
per_channel=False):
"""Returns the top k scores and their locations in a feature map.
......@@ -305,21 +335,25 @@ def prediction_tensors_to_boxes(detection_scores, y_indices, x_indices,
number of boxes detected for each sample in the batch.
"""
_, _, width, _ = _get_shape(height_width_predictions, 4)
batch_size, num_boxes = _get_shape(y_indices, 2)
peak_spatial_indices = flattened_indices_from_row_col_indices(
y_indices, x_indices, width)
y_indices = _to_float32(y_indices)
x_indices = _to_float32(x_indices)
# TF Lite does not support tf.gather with batch_dims > 0, so we need to use
# tf_gather_nd instead and here we prepare the indices for that.
combined_indices = tf.stack([
_multi_range(batch_size, value_repetitions=num_boxes),
tf.reshape(y_indices, [-1]),
tf.reshape(x_indices, [-1])
], axis=1)
new_height_width = tf.gather_nd(height_width_predictions, combined_indices)
new_height_width = tf.reshape(new_height_width, [batch_size, num_boxes, -1])
height_width_flat = _flatten_spatial_dimensions(height_width_predictions)
offsets_flat = _flatten_spatial_dimensions(offset_predictions)
new_offsets = tf.gather_nd(offset_predictions, combined_indices)
offsets = tf.reshape(new_offsets, [batch_size, num_boxes, -1])
height_width = tf.gather(height_width_flat, peak_spatial_indices,
batch_dims=1)
height_width = tf.maximum(height_width, 0)
offsets = tf.gather(offsets_flat, peak_spatial_indices, batch_dims=1)
y_indices = _to_float32(y_indices)
x_indices = _to_float32(x_indices)
height_width = tf.maximum(new_height_width, 0)
heights, widths = tf.unstack(height_width, axis=2)
y_offsets, x_offsets = tf.unstack(offsets, axis=2)
......@@ -339,7 +373,7 @@ def prediction_tensors_to_temporal_offsets(
y_indices, x_indices, offset_predictions):
"""Converts CenterNet temporal offset map predictions to batched format.
This function is similiar to the box offset conversion function, as both
This function is similar to the box offset conversion function, as both
temporal offsets and box offsets are size-2 vectors.
Args:
......@@ -355,16 +389,19 @@ def prediction_tensors_to_temporal_offsets(
the object temporal offsets of (y, x) dimensions.
"""
_, _, width, _ = _get_shape(offset_predictions, 4)
batch_size, num_boxes = _get_shape(y_indices, 2)
peak_spatial_indices = flattened_indices_from_row_col_indices(
y_indices, x_indices, width)
y_indices = _to_float32(y_indices)
x_indices = _to_float32(x_indices)
# TF Lite does not support tf.gather with batch_dims > 0, so we need to use
# tf_gather_nd instead and here we prepare the indices for that.
combined_indices = tf.stack([
_multi_range(batch_size, value_repetitions=num_boxes),
tf.reshape(y_indices, [-1]),
tf.reshape(x_indices, [-1])
], axis=1)
offsets_flat = _flatten_spatial_dimensions(offset_predictions)
new_offsets = tf.gather_nd(offset_predictions, combined_indices)
offsets = tf.reshape(new_offsets, [batch_size, num_boxes, -1])
offsets = tf.gather(offsets_flat, peak_spatial_indices, batch_dims=1)
return offsets
......@@ -404,8 +441,7 @@ def prediction_tensors_to_keypoint_candidates(
keypoint type, as it's possible to filter some candidates due to the score
threshold.
"""
batch_size, _, width, num_keypoints = _get_shape(
keypoint_heatmap_predictions, 4)
batch_size, _, _, num_keypoints = _get_shape(keypoint_heatmap_predictions, 4)
# Get x, y and channel indices corresponding to the top indices in the
# keypoint heatmap predictions.
# Note that the top k candidates are produced for **each keypoint type**.
......@@ -417,19 +453,42 @@ def prediction_tensors_to_keypoint_candidates(
k=max_candidates,
per_channel=True))
peak_spatial_indices = flattened_indices_from_row_col_indices(
y_indices, x_indices, width)
# TF Lite does not support tf.gather with batch_dims > 0, so we need to use
# tf_gather_nd instead and here we prepare the indices for that.
_, num_indices = _get_shape(y_indices, 2)
combined_indices = tf.stack([
_multi_range(batch_size, value_repetitions=num_indices),
tf.reshape(y_indices, [-1]),
tf.reshape(x_indices, [-1])
], axis=1)
selected_offsets_flat = tf.gather_nd(keypoint_heatmap_offsets,
combined_indices)
selected_offsets = tf.reshape(selected_offsets_flat,
[batch_size, num_indices, -1])
y_indices = _to_float32(y_indices)
x_indices = _to_float32(x_indices)
offsets_flat = _flatten_spatial_dimensions(keypoint_heatmap_offsets)
selected_offsets = tf.gather(offsets_flat, peak_spatial_indices, batch_dims=1)
_, num_indices, num_channels = _get_shape(selected_offsets, 3)
_, _, num_channels = _get_shape(selected_offsets, 3)
if num_channels > 2:
# Offsets are per keypoint and the last dimension of selected_offsets
# contains all those offsets, so reshape the offsets to make sure that the
# last dimension contains (y_offset, x_offset) for a single keypoint.
reshaped_offsets = tf.reshape(selected_offsets,
[batch_size, num_indices, -1, 2])
offsets = tf.gather(reshaped_offsets, channel_indices, batch_dims=2)
# TF Lite does not support tf.gather with batch_dims > 0, so we need to use
# tf_gather_nd instead and here we prepare the indices for that. In this
# case, channel_indices indicates which keypoint to use the offset from.
combined_indices = tf.stack([
_multi_range(batch_size, value_repetitions=num_indices),
_multi_range(num_indices, range_repetitions=batch_size),
tf.reshape(channel_indices, [-1])
], axis=1)
offsets = tf.gather_nd(reshaped_offsets, combined_indices)
offsets = tf.reshape(offsets, [batch_size, num_indices, -1])
else:
offsets = selected_offsets
y_offsets, x_offsets = tf.unstack(offsets, axis=2)
......@@ -474,16 +533,18 @@ def regressed_keypoints_at_object_centers(regressed_keypoint_predictions,
regressed keypoints are gathered at the provided locations, and converted
to absolute coordinates in the output coordinate frame.
"""
batch_size, _, width, _ = _get_shape(regressed_keypoint_predictions, 4)
flattened_indices = flattened_indices_from_row_col_indices(
y_indices, x_indices, width)
_, num_instances = _get_shape(flattened_indices, 2)
regressed_keypoints_flat = _flatten_spatial_dimensions(
regressed_keypoint_predictions)
relative_regressed_keypoints = tf.gather(
regressed_keypoints_flat, flattened_indices, batch_dims=1)
batch_size, num_instances = _get_shape(y_indices, 2)
# TF Lite does not support tf.gather with batch_dims > 0, so we need to use
# tf_gather_nd instead and here we prepare the indices for that.
combined_indices = tf.stack([
_multi_range(batch_size, value_repetitions=num_instances),
tf.reshape(y_indices, [-1]),
tf.reshape(x_indices, [-1])
], axis=1)
relative_regressed_keypoints = tf.gather_nd(regressed_keypoint_predictions,
combined_indices)
relative_regressed_keypoints = tf.reshape(
relative_regressed_keypoints,
[batch_size, num_instances, -1, 2])
......@@ -792,22 +853,46 @@ def _gather_candidates_at_indices(keypoint_candidates, keypoint_scores,
gathered_keypoint_scores: a float tensor of shape [batch_size,
num_indices, num_keypoints, 2].
"""
batch_size, num_indices, num_keypoints = _get_shape(indices, 3)
# Transpose tensors so that all batch dimensions are up front.
keypoint_candidates_transposed = tf.transpose(keypoint_candidates,
[0, 2, 1, 3])
keypoint_scores_transposed = tf.transpose(keypoint_scores, [0, 2, 1])
nearby_candidate_inds_transposed = tf.transpose(indices,
[0, 2, 1])
nearby_candidate_coords_tranposed = tf.gather(
keypoint_candidates_transposed, nearby_candidate_inds_transposed,
batch_dims=2)
nearby_candidate_scores_transposed = tf.gather(
keypoint_scores_transposed, nearby_candidate_inds_transposed,
batch_dims=2)
gathered_keypoint_candidates = tf.transpose(nearby_candidate_coords_tranposed,
[0, 2, 1, 3])
nearby_candidate_inds_transposed = tf.transpose(indices, [0, 2, 1])
# TF Lite does not support tf.gather with batch_dims > 0, so we need to use
# tf_gather_nd instead and here we prepare the indices for that.
combined_indices = tf.stack([
_multi_range(
batch_size,
value_repetitions=num_keypoints * num_indices,
dtype=tf.int64),
_multi_range(
num_keypoints,
value_repetitions=num_indices,
range_repetitions=batch_size,
dtype=tf.int64),
tf.reshape(nearby_candidate_inds_transposed, [-1])
], axis=1)
nearby_candidate_coords_transposed = tf.gather_nd(
keypoint_candidates_transposed, combined_indices)
nearby_candidate_coords_transposed = tf.reshape(
nearby_candidate_coords_transposed,
[batch_size, num_keypoints, num_indices, -1])
nearby_candidate_scores_transposed = tf.gather_nd(keypoint_scores_transposed,
combined_indices)
nearby_candidate_scores_transposed = tf.reshape(
nearby_candidate_scores_transposed,
[batch_size, num_keypoints, num_indices])
gathered_keypoint_candidates = tf.transpose(
nearby_candidate_coords_transposed, [0, 2, 1, 3])
gathered_keypoint_scores = tf.transpose(nearby_candidate_scores_transposed,
[0, 2, 1])
return gathered_keypoint_candidates, gathered_keypoint_scores
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment