Commit 3d9ae6de authored by Ronny Votel, committed by TF Object Detection Team

Replacing tf.gather_nd since reshaping and applying tf.gather executes faster during training.

PiperOrigin-RevId: 368217758
parent e5459a6b
@@ -812,7 +812,20 @@ def get_batch_predictions_from_indices(batch_predictions, indices):
     values: A tensor of shape [num_instances, channels] holding the predicted
       values at the given indices.
   """
-  return tf.gather_nd(batch_predictions, indices)
+  # Note, gather_nd (and its gradient scatter_nd) runs significantly slower (on
+  # TPU) than gather with flattened inputs, so reshape the tensor, flatten the
+  # indices, and run gather.
+  shape = shape_utils.combined_static_and_dynamic_shape(batch_predictions)
+
+  # [B, H, W, C] -> [H*W, W, 1] or [B, H, W, N, C] -> [H*W*N, W*N, N, 1]
+  rev_cum_interior_indices = tf.reverse(tf.math.cumprod(shape[-2:0:-1]), [0])
+  rev_cum_interior_indices = tf.concat([rev_cum_interior_indices, [1]], axis=0)
+
+  # Compute flattened indices and gather.
+  flattened_inds = tf.linalg.matmul(
+      indices, rev_cum_interior_indices[:, tf.newaxis])[:, 0]
+  batch_predictions_2d = tf.reshape(batch_predictions, [-1, shape[-1]])
+  return tf.gather(batch_predictions_2d, flattened_inds, axis=0)


 def _compute_std_dev_from_box_size(boxes_height, boxes_width, min_overlap):
...
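For reference, the flattening trick can be sketched outside the repository as a minimal, self-contained snippet. It inlines the stride computation with tf.shape instead of calling shape_utils.combined_static_and_dynamic_shape, and the helper name gather_from_indices_flattened is illustrative only; the snippet just checks that reshape-plus-gather matches tf.gather_nd on a small random tensor.

import tensorflow as tf

def gather_from_indices_flattened(batch_predictions, indices):
  # Per-column strides, e.g. [H*W, W, 1] for a [B, H, W, C] input or
  # [H*W*N, W*N, N, 1] for a [B, H, W, N, C] input, mirroring the new code.
  shape = tf.shape(batch_predictions)
  strides = tf.reverse(tf.math.cumprod(shape[-2:0:-1]), [0])
  strides = tf.concat([strides, [1]], axis=0)
  # Convert [num_instances, k] multi-indices into flat row offsets.
  flat_inds = tf.linalg.matmul(indices, strides[:, tf.newaxis])[:, 0]
  # Flatten everything except the channel axis and gather whole rows.
  flat_preds = tf.reshape(batch_predictions, [-1, shape[-1]])
  return tf.gather(flat_preds, flat_inds, axis=0)

# Sanity check against tf.gather_nd on a small random tensor.
preds = tf.random.normal([2, 40, 20, 2])
inds = tf.constant([[0, 20, 10], [1, 14, 11]], dtype=tf.int32)
tf.debugging.assert_near(
    tf.gather_nd(preds, inds),
    gather_from_indices_flattened(preds, inds))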
@@ -1628,22 +1628,48 @@ class CenterNetBoxTargetAssignerTest(test_case.TestCase):
     """

     def graph_fn():
-      box_batch = [
-          tf.constant([self._box_center, self._box_lower_left]),
-          tf.constant([self._box_center_small, self._box_odd_coordinates]),
-      ]
       pred_array = np.ones((2, 40, 20, 2), dtype=np.int32) * -1000
       pred_array[0, 20, 10] = [1, 2]
       pred_array[0, 30, 5] = [3, 4]
       pred_array[1, 20, 10] = [5, 6]
       pred_array[1, 14, 11] = [7, 8]
       pred_tensor = tf.constant(pred_array)
-      cn_assigner = targetassigner.CenterNetBoxTargetAssigner(4)
-      indices, _, _, _ = cn_assigner.assign_size_and_offset_targets(
-          160, 80, box_batch)
+      indices = tf.constant([
+          [0, 20, 10],
+          [0, 30, 5],
+          [1, 20, 10],
+          [1, 14, 11]
+      ], dtype=tf.int32)
       preds = targetassigner.get_batch_predictions_from_indices(
           pred_tensor, indices)
       return preds

     preds = self.execute(graph_fn, [])
     np.testing.assert_array_equal(preds, [[1, 2], [3, 4], [5, 6], [7, 8]])

+  def test_get_batch_predictions_from_indices_with_class(self):
+    """Test the get_batch_predictions_from_indices function with class axis.
+
+    This test verifies that the indices returned by
+    assign_size_and_offset_targets function work as expected with a predicted
+    tensor.
+    """
+    def graph_fn():
+      pred_array = np.ones((2, 40, 20, 5, 2), dtype=np.int32) * -1000
+      pred_array[0, 20, 10, 0] = [1, 2]
+      pred_array[0, 30, 5, 2] = [3, 4]
+      pred_array[1, 20, 10, 1] = [5, 6]
+      pred_array[1, 14, 11, 4] = [7, 8]
+      pred_tensor = tf.constant(pred_array)
+      indices = tf.constant([
+          [0, 20, 10, 0],
+          [0, 30, 5, 2],
+          [1, 20, 10, 1],
+          [1, 14, 11, 4]
+      ], dtype=tf.int32)
+      preds = targetassigner.get_batch_predictions_from_indices(
+          pred_tensor, indices)
...
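As a quick sanity check of the stride arithmetic exercised by the new with-class test, the NumPy sketch below recomputes the flat row offset for one of the indices above ([1, 14, 11, 4]) by hand. Shapes and values are taken from the test; the snippet itself is illustrative and not part of the change.

import numpy as np

B, H, W, N, C = 2, 40, 20, 5, 2
# Strides matching the comment in the new implementation:
# [B, H, W, N, C] -> [H*W*N, W*N, N, 1].
strides = np.array([H * W * N, W * N, N, 1])

index = np.array([1, 14, 11, 4])
flat = int(index @ strides)  # 1*4000 + 14*100 + 11*5 + 4 = 5459

pred = np.full((B, H, W, N, C), -1000, dtype=np.int32)
pred[1, 14, 11, 4] = [7, 8]
assert np.array_equal(pred.reshape(-1, C)[flat], [7, 8])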
@@ -2859,8 +2859,9 @@ class CenterNetMetaArch(model.DetectionModel):
     # Keypoint offset loss.
     loss = 0.0
     for prediction in depth_predictions:
-      selected_depths = cn_assigner.get_batch_predictions_from_indices(
-          prediction, batch_indices)
+      # TODO(yuhuic): Update this function to use
+      # cn_assigner.get_batch_predictions_from_indices().
+      selected_depths = tf.gather_nd(prediction, batch_indices)
       if kp_params.per_keypoint_offset and kp_params.per_keypoint_depth:
         selected_depths = tf.expand_dims(selected_depths, axis=-1)
       # The dimensions passed are not as per the doc string but the loss
...
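For context on the branch above, here is a toy sketch of the retained tf.gather_nd plus tf.expand_dims pattern, using hypothetical shapes (a [batch, height, width, num_keypoints] depth map and full-rank indices; the real tensors in the meta-arch may differ). When the indices address every axis, gather_nd returns one scalar per row, and expand_dims restores a trailing axis for the loss.

import tensorflow as tf

# Hypothetical shapes: [batch, height, width, num_keypoints] depth predictions
# and one (b, y, x, keypoint) row per selected instance.
prediction = tf.random.normal([2, 40, 20, 5])
batch_indices = tf.constant([[0, 20, 10, 0],
                             [1, 14, 11, 4]], dtype=tf.int32)

selected_depths = tf.gather_nd(prediction, batch_indices)   # shape [2]
# Add a trailing axis so the loss sees [num_instances, 1].
selected_depths = tf.expand_dims(selected_depths, axis=-1)  # shape [2, 1]
print(selected_depths.shape)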