Commit 8cf196cc, authored by A. Unique TensorFlower and committed by TF Object Detection Team
Browse files

Prevent float16 accuracy issues

PiperOrigin-RevId: 370575507
parent 95d1b298
...@@ -571,15 +571,11 @@ def argmax_feature_map_locations(feature_map): ...@@ -571,15 +571,11 @@ def argmax_feature_map_locations(feature_map):
feature_map, [batch_size, -1, num_channels]) feature_map, [batch_size, -1, num_channels])
peak_flat_indices = tf.math.argmax( peak_flat_indices = tf.math.argmax(
feature_map_flattened, axis=1, output_type=tf.dtypes.int32) feature_map_flattened, axis=1, output_type=tf.dtypes.int32)
# Convert the indices such that they represent the location in the full # Get x and y indices corresponding to the top indices in the flat array.
# (flattened) feature map of size [batch, height * width * channels]. y_indices, x_indices = (
channel_idx = tf.range(num_channels)[tf.newaxis, :] row_col_indices_from_flattened_indices(peak_flat_indices, width))
peak_flat_indices = num_channels * peak_flat_indices + channel_idx channel_indices = tf.tile(
# Get x, y and channel indices corresponding to the top indices in the flat tf.range(num_channels)[tf.newaxis, :], [batch_size, 1])
# array.
y_indices, x_indices, channel_indices = (
row_col_channel_indices_from_flattened_indices(
peak_flat_indices, width, num_channels))
return y_indices, x_indices, channel_indices return y_indices, x_indices, channel_indices
...@@ -1247,6 +1243,12 @@ def row_col_channel_indices_from_flattened_indices(indices, num_cols, ...@@ -1247,6 +1243,12 @@ def row_col_channel_indices_from_flattened_indices(indices, num_cols,
indices. indices.
""" """
# Be careful with this function when running a model in float16 precision
# (e.g. TF.js with WebGL) because the array indices may not be represented
# accurately if they are too large, resulting in incorrect channel indices.
# See:
# https://en.wikipedia.org/wiki/Half-precision_floating-point_format#Precision_limitations_on_integer_values
#
# Avoid using mod operator to make the ops more easy to be compatible with # Avoid using mod operator to make the ops more easy to be compatible with
# different environments, e.g. WASM. # different environments, e.g. WASM.
row_indices = (indices // num_channels) // num_cols row_indices = (indices // num_channels) // num_cols
...@@ -1257,6 +1259,29 @@ def row_col_channel_indices_from_flattened_indices(indices, num_cols, ...@@ -1257,6 +1259,29 @@ def row_col_channel_indices_from_flattened_indices(indices, num_cols,
return row_indices, col_indices, channel_indices return row_indices, col_indices, channel_indices
def row_col_indices_from_flattened_indices(indices, num_cols):
  """Splits flattened (row-major) indices into row and column indices.

  Args:
    indices: An integer tensor of any shape holding the indices in the
      flattened space.
    num_cols: Number of columns in the image (width).

  Returns:
    A `(row_indices, col_indices)` tuple. Each element has the same shape as
    `indices` and holds, respectively, the row and the column index
    corresponding to each of the input indices.
  """
  # The column is recovered as a remainder via floor-division, multiplication
  # and subtraction rather than the mod operator, so that the ops stay
  # compatible with different environments, e.g. WASM.
  rows = indices // num_cols
  cols = indices - rows * num_cols
  return rows, cols
def get_valid_anchor_weights_in_flattened_image(true_image_shapes, height, def get_valid_anchor_weights_in_flattened_image(true_image_shapes, height,
width): width):
"""Computes valid anchor weights for an image assuming pixels will be flattened. """Computes valid anchor weights for an image assuming pixels will be flattened.
......
...@@ -55,7 +55,7 @@ class CenterNetMetaArchPredictionHeadTest( ...@@ -55,7 +55,7 @@ class CenterNetMetaArchPredictionHeadTest(
class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase): class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase):
"""Test for CenterNet meta architecture related functions.""" """Test for CenterNet meta architecture related functions."""
def test_row_col_indices_from_flattened_indices(self): def test_row_col_channel_indices_from_flattened_indices(self):
"""Tests that the computation of row, col, channel indices is correct.""" """Tests that the computation of row, col, channel indices is correct."""
r_grid, c_grid, ch_grid = (np.zeros((5, 4, 3), dtype=np.int), r_grid, c_grid, ch_grid = (np.zeros((5, 4, 3), dtype=np.int),
...@@ -89,6 +89,21 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase): ...@@ -89,6 +89,21 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase):
np.testing.assert_array_equal(ci, c_grid.flatten()) np.testing.assert_array_equal(ci, c_grid.flatten())
np.testing.assert_array_equal(chi, ch_grid.flatten()) np.testing.assert_array_equal(chi, ch_grid.flatten())
def test_row_col_indices_from_flattened_indices(self):
  """Checks row/col recovery for every flat index of a 5 x 4 grid."""
  num_rows, num_cols = 5, 4
  # Row-major layout: each row index repeats `num_cols` times in a run, while
  # the column indices 0..num_cols-1 cycle once per row.
  expected_rows = np.repeat(np.arange(num_rows), num_cols)
  expected_cols = np.tile(np.arange(num_cols), num_rows)

  flat_indices = np.arange(num_rows * num_cols)
  rows, cols = cnma.row_col_indices_from_flattened_indices(
      flat_indices, num_cols)

  np.testing.assert_array_equal(rows, expected_rows)
  np.testing.assert_array_equal(cols, expected_cols)
def test_flattened_indices_from_row_col_indices(self): def test_flattened_indices_from_row_col_indices(self):
r = np.array( r = np.array(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment