Commit e5459a6b authored by Vighnesh Birodkar's avatar Vighnesh Birodkar Committed by TF Object Detection Team
Browse files

Clip values in strided box predictions.

PiperOrigin-RevId: 368215027
parent 87796817
...@@ -361,6 +361,8 @@ def prediction_tensors_to_boxes(y_indices, x_indices, height_width_predictions, ...@@ -361,6 +361,8 @@ def prediction_tensors_to_boxes(y_indices, x_indices, height_width_predictions,
the raw bounding box coordinates of boxes. the raw bounding box coordinates of boxes.
""" """
batch_size, num_boxes = _get_shape(y_indices, 2) batch_size, num_boxes = _get_shape(y_indices, 2)
_, height, width, _ = _get_shape(height_width_predictions, 4)
height, width = tf.cast(height, tf.float32), tf.cast(width, tf.float32)
# TF Lite does not support tf.gather with batch_dims > 0, so we need to use # TF Lite does not support tf.gather with batch_dims > 0, so we need to use
# tf_gather_nd instead and here we prepare the indices for that. # tf_gather_nd instead and here we prepare the indices for that.
...@@ -382,10 +384,16 @@ def prediction_tensors_to_boxes(y_indices, x_indices, height_width_predictions, ...@@ -382,10 +384,16 @@ def prediction_tensors_to_boxes(y_indices, x_indices, height_width_predictions,
heights, widths = tf.unstack(height_width, axis=2) heights, widths = tf.unstack(height_width, axis=2)
y_offsets, x_offsets = tf.unstack(offsets, axis=2) y_offsets, x_offsets = tf.unstack(offsets, axis=2)
boxes = tf.stack([y_indices + y_offsets - heights / 2.0, ymin = y_indices + y_offsets - heights / 2.0
x_indices + x_offsets - widths / 2.0, xmin = x_indices + x_offsets - widths / 2.0
y_indices + y_offsets + heights / 2.0, ymax = y_indices + y_offsets + heights / 2.0
x_indices + x_offsets + widths / 2.0], axis=2) xmax = x_indices + x_offsets + widths / 2.0
ymin = tf.clip_by_value(ymin, 0., height)
xmin = tf.clip_by_value(xmin, 0., width)
ymax = tf.clip_by_value(ymax, 0., height)
xmax = tf.clip_by_value(xmax, 0., width)
boxes = tf.stack([ymin, xmin, ymax, xmax], axis=2)
return boxes return boxes
......
...@@ -600,7 +600,7 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase): ...@@ -600,7 +600,7 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase):
boxes = self.execute(graph_fn, []) boxes = self.execute(graph_fn, [])
np.testing.assert_allclose( np.testing.assert_allclose(
[[-9, -8, 31, 52], [25, 35, 75, 85]], boxes[0]) [[0, 0, 31, 52], [25, 35, 75, 85]], boxes[0])
np.testing.assert_allclose( np.testing.assert_allclose(
[[96, 98, 106, 108], [96, 98, 106, 108]], boxes[1]) [[96, 98, 106, 108], [96, 98, 106, 108]], boxes[1])
np.testing.assert_allclose( np.testing.assert_allclose(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment