Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
8cf196cc
"test/vscode:/vscode.git/clone" did not exist on "91847e382acbfdec0a40a58918e47bbd2191ef6a"
Commit
8cf196cc
authored
Apr 26, 2021
by
A. Unique TensorFlower
Committed by
TF Object Detection Team
Apr 26, 2021
Browse files
Prevent float16 accuracy issues
PiperOrigin-RevId: 370575507
parent
95d1b298
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
50 additions
and
10 deletions
+50
-10
research/object_detection/meta_architectures/center_net_meta_arch.py
...ject_detection/meta_architectures/center_net_meta_arch.py
+34
-9
research/object_detection/meta_architectures/center_net_meta_arch_tf2_test.py
...ction/meta_architectures/center_net_meta_arch_tf2_test.py
+16
-1
No files found.
research/object_detection/meta_architectures/center_net_meta_arch.py
View file @
8cf196cc
...
...
@@ -571,15 +571,11 @@ def argmax_feature_map_locations(feature_map):
feature_map
,
[
batch_size
,
-
1
,
num_channels
])
peak_flat_indices
=
tf
.
math
.
argmax
(
feature_map_flattened
,
axis
=
1
,
output_type
=
tf
.
dtypes
.
int32
)
# Convert the indices such that they represent the location in the full
# (flattened) feature map of size [batch, height * width * channels].
channel_idx
=
tf
.
range
(
num_channels
)[
tf
.
newaxis
,
:]
peak_flat_indices
=
num_channels
*
peak_flat_indices
+
channel_idx
# Get x, y and channel indices corresponding to the top indices in the flat
# array.
y_indices
,
x_indices
,
channel_indices
=
(
row_col_channel_indices_from_flattened_indices
(
peak_flat_indices
,
width
,
num_channels
))
# Get x and y indices corresponding to the top indices in the flat array.
y_indices
,
x_indices
=
(
row_col_indices_from_flattened_indices
(
peak_flat_indices
,
width
))
channel_indices
=
tf
.
tile
(
tf
.
range
(
num_channels
)[
tf
.
newaxis
,
:],
[
batch_size
,
1
])
return
y_indices
,
x_indices
,
channel_indices
...
...
@@ -1247,6 +1243,12 @@ def row_col_channel_indices_from_flattened_indices(indices, num_cols,
indices.
"""
# Be careful with this function when running a model in float16 precision
# (e.g. TF.js with WebGL) because the array indices may not be represented
# accurately if they are too large, resulting in incorrect channel indices.
# See:
# https://en.wikipedia.org/wiki/Half-precision_floating-point_format#Precision_limitations_on_integer_values
#
# Avoid using mod operator to make the ops more easy to be compatible with
# different environments, e.g. WASM.
row_indices
=
(
indices
//
num_channels
)
//
num_cols
...
...
@@ -1257,6 +1259,29 @@ def row_col_channel_indices_from_flattened_indices(indices, num_cols,
return
row_indices
,
col_indices
,
channel_indices
def
row_col_indices_from_flattened_indices
(
indices
,
num_cols
):
"""Computes row and column indices from flattened indices.
Args:
indices: An integer tensor of any shape holding the indices in the flattened
space.
num_cols: Number of columns in the image (width).
Returns:
row_indices: The row indices corresponding to each of the input indices.
Same shape as indices.
col_indices: The column indices corresponding to each of the input indices.
Same shape as indices.
"""
# Avoid using mod operator to make the ops more easy to be compatible with
# different environments, e.g. WASM.
row_indices
=
indices
//
num_cols
col_indices
=
indices
-
row_indices
*
num_cols
return
row_indices
,
col_indices
def
get_valid_anchor_weights_in_flattened_image
(
true_image_shapes
,
height
,
width
):
"""Computes valid anchor weights for an image assuming pixels will be flattened.
...
...
research/object_detection/meta_architectures/center_net_meta_arch_tf2_test.py
View file @
8cf196cc
...
...
@@ -55,7 +55,7 @@ class CenterNetMetaArchPredictionHeadTest(
class
CenterNetMetaArchHelpersTest
(
test_case
.
TestCase
,
parameterized
.
TestCase
):
"""Test for CenterNet meta architecture related functions."""
def
test_row_col_indices_from_flattened_indices
(
self
):
def
test_row_col_
channel_
indices_from_flattened_indices
(
self
):
"""Tests that the computation of row, col, channel indices is correct."""
r_grid
,
c_grid
,
ch_grid
=
(
np
.
zeros
((
5
,
4
,
3
),
dtype
=
np
.
int
),
...
...
@@ -89,6 +89,21 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase):
np
.
testing
.
assert_array_equal
(
ci
,
c_grid
.
flatten
())
np
.
testing
.
assert_array_equal
(
chi
,
ch_grid
.
flatten
())
def
test_row_col_indices_from_flattened_indices
(
self
):
"""Tests that the computation of row, col indices is correct."""
r_grid
=
np
.
array
([[
0
,
0
,
0
,
0
],
[
1
,
1
,
1
,
1
],
[
2
,
2
,
2
,
2
],
[
3
,
3
,
3
,
3
],
[
4
,
4
,
4
,
4
]])
c_grid
=
np
.
array
([[
0
,
1
,
2
,
3
],
[
0
,
1
,
2
,
3
],
[
0
,
1
,
2
,
3
],
[
0
,
1
,
2
,
3
],
[
0
,
1
,
2
,
3
]])
indices
=
np
.
arange
(
20
)
ri
,
ci
,
=
cnma
.
row_col_indices_from_flattened_indices
(
indices
,
4
)
np
.
testing
.
assert_array_equal
(
ri
,
r_grid
.
flatten
())
np
.
testing
.
assert_array_equal
(
ci
,
c_grid
.
flatten
())
def
test_flattened_indices_from_row_col_indices
(
self
):
r
=
np
.
array
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment