Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
8cf196cc
Commit
8cf196cc
authored
Apr 26, 2021
by
A. Unique TensorFlower
Committed by
TF Object Detection Team
Apr 26, 2021
Browse files
Prevent float16 accuracy issues
PiperOrigin-RevId: 370575507
parent
95d1b298
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
50 additions
and
10 deletions
+50
-10
research/object_detection/meta_architectures/center_net_meta_arch.py
...ject_detection/meta_architectures/center_net_meta_arch.py
+34
-9
research/object_detection/meta_architectures/center_net_meta_arch_tf2_test.py
...ction/meta_architectures/center_net_meta_arch_tf2_test.py
+16
-1
No files found.
research/object_detection/meta_architectures/center_net_meta_arch.py
View file @
8cf196cc
...
@@ -571,15 +571,11 @@ def argmax_feature_map_locations(feature_map):
...
@@ -571,15 +571,11 @@ def argmax_feature_map_locations(feature_map):
feature_map
,
[
batch_size
,
-
1
,
num_channels
])
feature_map
,
[
batch_size
,
-
1
,
num_channels
])
peak_flat_indices
=
tf
.
math
.
argmax
(
peak_flat_indices
=
tf
.
math
.
argmax
(
feature_map_flattened
,
axis
=
1
,
output_type
=
tf
.
dtypes
.
int32
)
feature_map_flattened
,
axis
=
1
,
output_type
=
tf
.
dtypes
.
int32
)
# Convert the indices such that they represent the location in the full
# Get x and y indices corresponding to the top indices in the flat array.
# (flattened) feature map of size [batch, height * width * channels].
y_indices
,
x_indices
=
(
channel_idx
=
tf
.
range
(
num_channels
)[
tf
.
newaxis
,
:]
row_col_indices_from_flattened_indices
(
peak_flat_indices
,
width
))
peak_flat_indices
=
num_channels
*
peak_flat_indices
+
channel_idx
channel_indices
=
tf
.
tile
(
# Get x, y and channel indices corresponding to the top indices in the flat
tf
.
range
(
num_channels
)[
tf
.
newaxis
,
:],
[
batch_size
,
1
])
# array.
y_indices
,
x_indices
,
channel_indices
=
(
row_col_channel_indices_from_flattened_indices
(
peak_flat_indices
,
width
,
num_channels
))
return
y_indices
,
x_indices
,
channel_indices
return
y_indices
,
x_indices
,
channel_indices
...
@@ -1247,6 +1243,12 @@ def row_col_channel_indices_from_flattened_indices(indices, num_cols,
...
@@ -1247,6 +1243,12 @@ def row_col_channel_indices_from_flattened_indices(indices, num_cols,
indices.
indices.
"""
"""
# Be careful with this function when running a model in float16 precision
# (e.g. TF.js with WebGL) because the array indices may not be represented
# accurately if they are too large, resulting in incorrect channel indices.
# See:
# https://en.wikipedia.org/wiki/Half-precision_floating-point_format#Precision_limitations_on_integer_values
#
# Avoid using mod operator to make the ops more easy to be compatible with
# Avoid using mod operator to make the ops more easy to be compatible with
# different environments, e.g. WASM.
# different environments, e.g. WASM.
row_indices
=
(
indices
//
num_channels
)
//
num_cols
row_indices
=
(
indices
//
num_channels
)
//
num_cols
...
@@ -1257,6 +1259,29 @@ def row_col_channel_indices_from_flattened_indices(indices, num_cols,
...
@@ -1257,6 +1259,29 @@ def row_col_channel_indices_from_flattened_indices(indices, num_cols,
return
row_indices
,
col_indices
,
channel_indices
return
row_indices
,
col_indices
,
channel_indices
def
row_col_indices_from_flattened_indices
(
indices
,
num_cols
):
"""Computes row and column indices from flattened indices.
Args:
indices: An integer tensor of any shape holding the indices in the flattened
space.
num_cols: Number of columns in the image (width).
Returns:
row_indices: The row indices corresponding to each of the input indices.
Same shape as indices.
col_indices: The column indices corresponding to each of the input indices.
Same shape as indices.
"""
# Avoid using mod operator to make the ops more easy to be compatible with
# different environments, e.g. WASM.
row_indices
=
indices
//
num_cols
col_indices
=
indices
-
row_indices
*
num_cols
return
row_indices
,
col_indices
def
get_valid_anchor_weights_in_flattened_image
(
true_image_shapes
,
height
,
def
get_valid_anchor_weights_in_flattened_image
(
true_image_shapes
,
height
,
width
):
width
):
"""Computes valid anchor weights for an image assuming pixels will be flattened.
"""Computes valid anchor weights for an image assuming pixels will be flattened.
...
...
research/object_detection/meta_architectures/center_net_meta_arch_tf2_test.py
View file @
8cf196cc
...
@@ -55,7 +55,7 @@ class CenterNetMetaArchPredictionHeadTest(
...
@@ -55,7 +55,7 @@ class CenterNetMetaArchPredictionHeadTest(
class
CenterNetMetaArchHelpersTest
(
test_case
.
TestCase
,
parameterized
.
TestCase
):
class
CenterNetMetaArchHelpersTest
(
test_case
.
TestCase
,
parameterized
.
TestCase
):
"""Test for CenterNet meta architecture related functions."""
"""Test for CenterNet meta architecture related functions."""
def
test_row_col_indices_from_flattened_indices
(
self
):
def
test_row_col_
channel_
indices_from_flattened_indices
(
self
):
"""Tests that the computation of row, col, channel indices is correct."""
"""Tests that the computation of row, col, channel indices is correct."""
r_grid
,
c_grid
,
ch_grid
=
(
np
.
zeros
((
5
,
4
,
3
),
dtype
=
np
.
int
),
r_grid
,
c_grid
,
ch_grid
=
(
np
.
zeros
((
5
,
4
,
3
),
dtype
=
np
.
int
),
...
@@ -89,6 +89,21 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase):
...
@@ -89,6 +89,21 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase):
np
.
testing
.
assert_array_equal
(
ci
,
c_grid
.
flatten
())
np
.
testing
.
assert_array_equal
(
ci
,
c_grid
.
flatten
())
np
.
testing
.
assert_array_equal
(
chi
,
ch_grid
.
flatten
())
np
.
testing
.
assert_array_equal
(
chi
,
ch_grid
.
flatten
())
def
test_row_col_indices_from_flattened_indices
(
self
):
"""Tests that the computation of row, col indices is correct."""
r_grid
=
np
.
array
([[
0
,
0
,
0
,
0
],
[
1
,
1
,
1
,
1
],
[
2
,
2
,
2
,
2
],
[
3
,
3
,
3
,
3
],
[
4
,
4
,
4
,
4
]])
c_grid
=
np
.
array
([[
0
,
1
,
2
,
3
],
[
0
,
1
,
2
,
3
],
[
0
,
1
,
2
,
3
],
[
0
,
1
,
2
,
3
],
[
0
,
1
,
2
,
3
]])
indices
=
np
.
arange
(
20
)
ri
,
ci
,
=
cnma
.
row_col_indices_from_flattened_indices
(
indices
,
4
)
np
.
testing
.
assert_array_equal
(
ri
,
r_grid
.
flatten
())
np
.
testing
.
assert_array_equal
(
ci
,
c_grid
.
flatten
())
def
test_flattened_indices_from_row_col_indices
(
self
):
def
test_flattened_indices_from_row_col_indices
(
self
):
r
=
np
.
array
(
r
=
np
.
array
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment