Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
71c7b7f9
Commit
71c7b7f9
authored
Apr 26, 2021
by
A. Unique TensorFlower
Browse files
Support variable batch size in detection generator.
PiperOrigin-RevId: 370548032
parent
0c803498
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
16 deletions
+14
-16
official/vision/beta/modeling/layers/detection_generator.py
official/vision/beta/modeling/layers/detection_generator.py
+14
-16
No files found.
official/vision/beta/modeling/layers/detection_generator.py
View file @
71c7b7f9
...
...
@@ -242,6 +242,8 @@ def _select_top_k_scores(scores_in: tf.Tensor, pre_nms_num_detections: int):
`[batch_size, pre_nms_num_detections, num_classes]`.
"""
batch_size
,
num_anchors
,
num_class
=
scores_in
.
get_shape
().
as_list
()
if
batch_size
is
None
:
batch_size
=
tf
.
shape
(
scores_in
)[
0
]
scores_trans
=
tf
.
transpose
(
scores_in
,
perm
=
[
0
,
2
,
1
])
scores_trans
=
tf
.
reshape
(
scores_trans
,
[
-
1
,
num_anchors
])
...
...
@@ -304,6 +306,8 @@ def _generate_detections_v2(boxes: tf.Tensor,
nmsed_scores
=
[]
valid_detections
=
[]
batch_size
,
_
,
num_classes_for_box
,
_
=
boxes
.
get_shape
().
as_list
()
if
batch_size
is
None
:
batch_size
=
tf
.
shape
(
boxes
)[
0
]
_
,
total_anchors
,
num_classes
=
scores
.
get_shape
().
as_list
()
# Selects top pre_nms_num scores and indices before NMS.
scores
,
indices
=
_select_top_k_scores
(
...
...
@@ -465,25 +469,20 @@ class DetectionGenerator(tf.keras.layers.Layer):
# Removes the background class.
box_scores_shape
=
tf
.
shape
(
box_scores
)
box_scores_shape_list
=
box_scores
.
get_shape
().
as_list
()
batch_size
=
box_scores_shape
[
0
]
num_locations
=
box_scores_shape
[
1
]
num_classes
=
box_scores_shape
[
-
1
]
num_locations
=
box_scores_shape
_list
[
1
]
num_classes
=
box_scores_shape
_list
[
-
1
]
num_detections
=
num_locations
*
(
num_classes
-
1
)
box_scores
=
tf
.
slice
(
box_scores
,
[
0
,
0
,
1
],
[
-
1
,
-
1
,
-
1
])
raw_boxes
=
tf
.
reshape
(
raw_boxes
,
tf
.
stack
([
batch_size
,
num_locations
,
num_classes
,
4
],
axis
=-
1
))
raw_boxes
=
tf
.
slice
(
raw_boxes
,
[
0
,
0
,
1
,
0
],
[
-
1
,
-
1
,
-
1
,
-
1
])
raw_boxes
=
tf
.
reshape
(
raw_boxes
,
[
batch_size
,
num_locations
,
num_classes
,
4
])
raw_boxes
=
tf
.
slice
(
raw_boxes
,
[
0
,
0
,
1
,
0
],
[
-
1
,
-
1
,
-
1
,
-
1
])
anchor_boxes
=
tf
.
tile
(
tf
.
expand_dims
(
anchor_boxes
,
axis
=
2
),
[
1
,
1
,
num_classes
-
1
,
1
])
raw_boxes
=
tf
.
reshape
(
raw_boxes
,
tf
.
stack
([
batch_size
,
num_detections
,
4
],
axis
=-
1
))
anchor_boxes
=
tf
.
reshape
(
anchor_boxes
,
tf
.
stack
([
batch_size
,
num_detections
,
4
],
axis
=-
1
))
raw_boxes
=
tf
.
reshape
(
raw_boxes
,
[
batch_size
,
num_detections
,
4
])
anchor_boxes
=
tf
.
reshape
(
anchor_boxes
,
[
batch_size
,
num_detections
,
4
])
# Box decoding.
decoded_boxes
=
box_ops
.
decode_boxes
(
...
...
@@ -493,9 +492,8 @@ class DetectionGenerator(tf.keras.layers.Layer):
decoded_boxes
=
box_ops
.
clip_boxes
(
decoded_boxes
,
tf
.
expand_dims
(
image_shape
,
axis
=
1
))
decoded_boxes
=
tf
.
reshape
(
decoded_boxes
,
tf
.
stack
([
batch_size
,
num_locations
,
num_classes
-
1
,
4
],
axis
=-
1
))
decoded_boxes
=
tf
.
reshape
(
decoded_boxes
,
[
batch_size
,
num_locations
,
num_classes
-
1
,
4
])
if
not
self
.
_config_dict
[
'apply_nms'
]:
return
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment