Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
3d9ae6de
Commit
3d9ae6de
authored
Apr 13, 2021
by
Ronny Votel
Committed by
TF Object Detection Team
Apr 13, 2021
Browse files
Replacing tf.gather_nd since reshaping and applying tf.gather executes faster during training.
PiperOrigin-RevId: 368217758
parent
e5459a6b
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
51 additions
and
11 deletions
+51
-11
research/object_detection/core/target_assigner.py
research/object_detection/core/target_assigner.py
+14
-1
research/object_detection/core/target_assigner_test.py
research/object_detection/core/target_assigner_test.py
+34
-8
research/object_detection/meta_architectures/center_net_meta_arch.py
...ject_detection/meta_architectures/center_net_meta_arch.py
+3
-2
No files found.
research/object_detection/core/target_assigner.py
View file @
3d9ae6de
...
@@ -812,7 +812,20 @@ def get_batch_predictions_from_indices(batch_predictions, indices):
...
@@ -812,7 +812,20 @@ def get_batch_predictions_from_indices(batch_predictions, indices):
values: A tensor of shape [num_instances, channels] holding the predicted
values: A tensor of shape [num_instances, channels] holding the predicted
values at the given indices.
values at the given indices.
"""
"""
return
tf
.
gather_nd
(
batch_predictions
,
indices
)
# Note, gather_nd (and its gradient scatter_nd) runs significantly slower (on
# TPU) than gather with flattened inputs, so reshape the tensor, flatten the
# indices, and run gather.
shape
=
shape_utils
.
combined_static_and_dynamic_shape
(
batch_predictions
)
# [B, H, W, C] -> [H*W, W, 1] or [B, H, W, N, C] -> [H*W*N, W*N, N, 1]
rev_cum_interior_indices
=
tf
.
reverse
(
tf
.
math
.
cumprod
(
shape
[
-
2
:
0
:
-
1
]),
[
0
])
rev_cum_interior_indices
=
tf
.
concat
([
rev_cum_interior_indices
,
[
1
]],
axis
=
0
)
# Compute flattened indices and gather.
flattened_inds
=
tf
.
linalg
.
matmul
(
indices
,
rev_cum_interior_indices
[:,
tf
.
newaxis
])[:,
0
]
batch_predictions_2d
=
tf
.
reshape
(
batch_predictions
,
[
-
1
,
shape
[
-
1
]])
return
tf
.
gather
(
batch_predictions_2d
,
flattened_inds
,
axis
=
0
)
def
_compute_std_dev_from_box_size
(
boxes_height
,
boxes_width
,
min_overlap
):
def
_compute_std_dev_from_box_size
(
boxes_height
,
boxes_width
,
min_overlap
):
...
...
research/object_detection/core/target_assigner_test.py
View file @
3d9ae6de
...
@@ -1628,22 +1628,48 @@ class CenterNetBoxTargetAssignerTest(test_case.TestCase):
...
@@ -1628,22 +1628,48 @@ class CenterNetBoxTargetAssignerTest(test_case.TestCase):
"""
"""
def
graph_fn
():
def
graph_fn
():
box_batch
=
[
tf
.
constant
([
self
.
_box_center
,
self
.
_box_lower_left
]),
tf
.
constant
([
self
.
_box_center_small
,
self
.
_box_odd_coordinates
]),
]
pred_array
=
np
.
ones
((
2
,
40
,
20
,
2
),
dtype
=
np
.
int32
)
*
-
1000
pred_array
=
np
.
ones
((
2
,
40
,
20
,
2
),
dtype
=
np
.
int32
)
*
-
1000
pred_array
[
0
,
20
,
10
]
=
[
1
,
2
]
pred_array
[
0
,
20
,
10
]
=
[
1
,
2
]
pred_array
[
0
,
30
,
5
]
=
[
3
,
4
]
pred_array
[
0
,
30
,
5
]
=
[
3
,
4
]
pred_array
[
1
,
20
,
10
]
=
[
5
,
6
]
pred_array
[
1
,
20
,
10
]
=
[
5
,
6
]
pred_array
[
1
,
14
,
11
]
=
[
7
,
8
]
pred_array
[
1
,
14
,
11
]
=
[
7
,
8
]
pred_tensor
=
tf
.
constant
(
pred_array
)
indices
=
tf
.
constant
([
[
0
,
20
,
10
],
[
0
,
30
,
5
],
[
1
,
20
,
10
],
[
1
,
14
,
11
]
],
dtype
=
tf
.
int32
)
preds
=
targetassigner
.
get_batch_predictions_from_indices
(
pred_tensor
,
indices
)
return
preds
preds
=
self
.
execute
(
graph_fn
,
[])
np
.
testing
.
assert_array_equal
(
preds
,
[[
1
,
2
],
[
3
,
4
],
[
5
,
6
],
[
7
,
8
]])
def
test_get_batch_predictions_from_indices_with_class
(
self
):
"""Test the get_batch_predictions_from_indices function with class axis.
This test verifies that the indices returned by
assign_size_and_offset_targets function work as expected with a predicted
tensor.
"""
def
graph_fn
():
pred_array
=
np
.
ones
((
2
,
40
,
20
,
5
,
2
),
dtype
=
np
.
int32
)
*
-
1000
pred_array
[
0
,
20
,
10
,
0
]
=
[
1
,
2
]
pred_array
[
0
,
30
,
5
,
2
]
=
[
3
,
4
]
pred_array
[
1
,
20
,
10
,
1
]
=
[
5
,
6
]
pred_array
[
1
,
14
,
11
,
4
]
=
[
7
,
8
]
pred_tensor
=
tf
.
constant
(
pred_array
)
pred_tensor
=
tf
.
constant
(
pred_array
)
cn_assigner
=
targetassigner
.
CenterNetBoxTargetAssigner
(
4
)
indices
=
tf
.
constant
([
indices
,
_
,
_
,
_
=
cn_assigner
.
assign_size_and_offset_targets
(
[
0
,
20
,
10
,
0
],
160
,
80
,
box_batch
)
[
0
,
30
,
5
,
2
],
[
1
,
20
,
10
,
1
],
[
1
,
14
,
11
,
4
]
],
dtype
=
tf
.
int32
)
preds
=
targetassigner
.
get_batch_predictions_from_indices
(
preds
=
targetassigner
.
get_batch_predictions_from_indices
(
pred_tensor
,
indices
)
pred_tensor
,
indices
)
...
...
research/object_detection/meta_architectures/center_net_meta_arch.py
View file @
3d9ae6de
...
@@ -2859,8 +2859,9 @@ class CenterNetMetaArch(model.DetectionModel):
...
@@ -2859,8 +2859,9 @@ class CenterNetMetaArch(model.DetectionModel):
# Keypoint offset loss.
# Keypoint offset loss.
loss
=
0.0
loss
=
0.0
for
prediction
in
depth_predictions
:
for
prediction
in
depth_predictions
:
selected_depths
=
cn_assigner
.
get_batch_predictions_from_indices
(
# TODO(yuhuic): Update this function to use
prediction
,
batch_indices
)
# cn_assigner.get_batch_predictions_from_indices().
selected_depths
=
tf
.
gather_nd
(
prediction
,
batch_indices
)
if
kp_params
.
per_keypoint_offset
and
kp_params
.
per_keypoint_depth
:
if
kp_params
.
per_keypoint_offset
and
kp_params
.
per_keypoint_depth
:
selected_depths
=
tf
.
expand_dims
(
selected_depths
,
axis
=-
1
)
selected_depths
=
tf
.
expand_dims
(
selected_depths
,
axis
=-
1
)
# The dimensions passed are not as per the doc string but the loss
# The dimensions passed are not as per the doc string but the loss
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment