Commit fcb152bf authored by A. Unique TensorFlower's avatar A. Unique TensorFlower Committed by TF Object Detection Team
Browse files

Limit k of top_k to the size of the dynamic input.

PiperOrigin-RevId: 399210729
parent 4b6cbef4
......@@ -317,7 +317,9 @@ def top_k_feature_map_locations(feature_map, max_pool_kernel_size=3, k=100,
feature_map_peaks_flat, axis=1, output_type=tf.dtypes.int32), axis=-1)
else:
feature_map_peaks_flat = tf.reshape(feature_map_peaks, [batch_size, -1])
scores, peak_flat_indices = tf.math.top_k(feature_map_peaks_flat, k=k)
safe_k = tf.minimum(k, tf.shape(feature_map_peaks_flat)[1])
scores, peak_flat_indices = tf.math.top_k(feature_map_peaks_flat,
k=safe_k)
# Get x, y and channel indices corresponding to the top indices in the flat
# array.
......
......@@ -469,6 +469,24 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase):
np.testing.assert_array_equal([1, 0, 2], x_inds[1])
np.testing.assert_array_equal([1, 0, 0], channel_inds[1])
def test_top_k_feature_map_locations_very_large(self):
feature_map_np = np.zeros((2, 3, 3, 2), dtype=np.float32)
feature_map_np[0, 2, 0, 1] = 1.0
def graph_fn():
feature_map = tf.constant(feature_map_np)
feature_map.set_shape(tf.TensorShape([2, 3, None, 2]))
scores, y_inds, x_inds, channel_inds = (
cnma.top_k_feature_map_locations(
feature_map, max_pool_kernel_size=1, k=3000))
return scores, y_inds, x_inds, channel_inds
# graph execution will fail if large k's are not handled.
scores, y_inds, x_inds, channel_inds = self.execute(graph_fn, [])
self.assertEqual(scores.shape, (2, 18))
self.assertEqual(y_inds.shape, (2, 18))
self.assertEqual(x_inds.shape, (2, 18))
self.assertEqual(channel_inds.shape, (2, 18))
def test_top_k_feature_map_locations_per_channel(self):
feature_map_np = np.zeros((2, 3, 3, 2), dtype=np.float32)
feature_map_np[0, 2, 0, 0] = 1.0 # Selected.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment