Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
9d4b102c
"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "f006521b61b5e3cbb1c64bd5a83d8e9a06e579a4"
Commit
9d4b102c
authored
Aug 16, 2020
by
Kaushik Shivakumar
Browse files
clean target assigner
parent
656ec2a6
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
79 deletions
+12
-79
research/object_detection/core/target_assigner.py
research/object_detection/core/target_assigner.py
+12
-79
No files found.
research/object_detection/core/target_assigner.py
View file @
9d4b102c
...
@@ -1953,12 +1953,7 @@ class DETRTargetAssigner(object):
...
@@ -1953,12 +1953,7 @@ class DETRTargetAssigner(object):
num_classes],
num_classes],
batch_reg_targets: a tensor with shape [batch_size, num_pred_boxes,
batch_reg_targets: a tensor with shape [batch_size, num_pred_boxes,
box_code_dimension]
box_code_dimension]
batch_reg_weights: a tensor with shape [batch_size, num_pred_boxes],
batch_reg_weights: a tensor with shape [batch_size, num_pred_boxes].
match: an int32 tensor of shape [batch_size, num_pred_boxes] containing
result of predicted box groundtruth matching. Each position in the
tensor indicates an predicted box and holds the following meaning:
(1) if match[x, i] >= 0, predicted box i is matched with groundtruth match[x, i].
(2) if match[x, i] = -1, predicted box i is marked to be background.
"""
"""
cls_targets_list
=
[]
cls_targets_list
=
[]
cls_weights_list
=
[]
cls_weights_list
=
[]
...
@@ -1989,7 +1984,7 @@ class DETRTargetAssigner(object):
...
@@ -1989,7 +1984,7 @@ class DETRTargetAssigner(object):
groundtruth_boxes
,
groundtruth_boxes
,
class_predictions
,
class_predictions
,
groundtruth_labels
,
groundtruth_labels
,
groundtruth_weights
):
groundtruth_weights
=
None
):
"""Assign classification and regression targets to each box_pred.
"""Assign classification and regression targets to each box_pred.
For a given set of box_preds and groundtruth detections, match box_preds
For a given set of box_preds and groundtruth detections, match box_preds
...
@@ -2028,12 +2023,7 @@ class DETRTargetAssigner(object):
...
@@ -2028,12 +2023,7 @@ class DETRTargetAssigner(object):
reg_weights: a float32 tensor with shape [num_box_preds]
reg_weights: a float32 tensor with shape [num_box_preds]
"""
"""
unmatched_class_label
=
tf
.
constant
([
1
]
+
[
0
]
*
groundtruth_labels
.
shape
[
1
],
tf
.
float32
)
unmatched_class_label
=
tf
.
constant
([
1
]
+
[
0
]
*
(
groundtruth_labels
.
shape
[
1
]
-
1
),
tf
.
float32
)
if
groundtruth_labels
is
None
:
groundtruth_labels
=
tf
.
ones
(
tf
.
expand_dims
(
groundtruth_boxes
.
num_boxes
(),
0
))
groundtruth_labels
=
tf
.
expand_dims
(
groundtruth_labels
,
-
1
)
if
groundtruth_weights
is
None
:
if
groundtruth_weights
is
None
:
num_gt_boxes
=
groundtruth_boxes
.
num_boxes_static
()
num_gt_boxes
=
groundtruth_boxes
.
num_boxes_static
()
...
@@ -2041,10 +2031,6 @@ class DETRTargetAssigner(object):
...
@@ -2041,10 +2031,6 @@ class DETRTargetAssigner(object):
num_gt_boxes
=
groundtruth_boxes
.
num_boxes
()
num_gt_boxes
=
groundtruth_boxes
.
num_boxes
()
groundtruth_weights
=
tf
.
ones
([
num_gt_boxes
],
dtype
=
tf
.
float32
)
groundtruth_weights
=
tf
.
ones
([
num_gt_boxes
],
dtype
=
tf
.
float32
)
# set scores on the gt boxes
scores
=
1
-
groundtruth_labels
[:,
0
]
groundtruth_boxes
.
add_field
(
fields
.
BoxListFields
.
scores
,
scores
)
groundtruth_boxes
.
add_field
(
fields
.
BoxListFields
.
classes
,
groundtruth_labels
)
groundtruth_boxes
.
add_field
(
fields
.
BoxListFields
.
classes
,
groundtruth_labels
)
box_preds
.
add_field
(
fields
.
BoxListFields
.
classes
,
class_predictions
)
box_preds
.
add_field
(
fields
.
BoxListFields
.
classes
,
class_predictions
)
...
@@ -2054,10 +2040,13 @@ class DETRTargetAssigner(object):
...
@@ -2054,10 +2040,13 @@ class DETRTargetAssigner(object):
match
=
self
.
_matcher
.
match
(
match_quality_matrix
,
match
=
self
.
_matcher
.
match
(
match_quality_matrix
,
valid_rows
=
tf
.
greater
(
groundtruth_weights
,
0
))
valid_rows
=
tf
.
greater
(
groundtruth_weights
,
0
))
reg_targets
=
self
.
_create_regression_targets
(
matched_gt_boxes
=
match
.
gather_based_on_match
(
box_preds
,
groundtruth_boxes
.
get
(),
groundtruth_boxes
,
unmatched_value
=
tf
.
zeros
(
4
),
match
)
ignored_value
=
tf
.
zeros
(
4
))
matched_gt_boxlist
=
box_list
.
BoxList
(
matched_gt_boxes
)
ty
,
tx
,
th
,
tw
=
matched_gt_boxlist
.
get_center_coordinates_and_sizes
()
reg_targets
=
tf
.
transpose
(
tf
.
stack
([
ty
,
tx
,
th
,
tw
]))
cls_targets
=
match
.
gather_based_on_match
(
cls_targets
=
match
.
gather_based_on_match
(
groundtruth_labels
,
groundtruth_labels
,
unmatched_value
=
unmatched_class_label
,
unmatched_value
=
unmatched_class_label
,
...
@@ -2073,66 +2062,10 @@ class DETRTargetAssigner(object):
...
@@ -2073,66 +2062,10 @@ class DETRTargetAssigner(object):
# convert cls_weights from per-box_pred to per-class.
# convert cls_weights from per-box_pred to per-class.
class_label_shape
=
tf
.
shape
(
cls_targets
)[
1
:]
class_label_shape
=
tf
.
shape
(
cls_targets
)[
1
:]
weights_shape
=
tf
.
shape
(
cls_weights
)
weights_multiple
=
tf
.
concat
(
weights_multiple
=
tf
.
concat
(
[
tf
.
on
es_like
(
weights_shape
),
class_label_shape
],
[
tf
.
c
on
stant
([
1
]
),
class_label_shape
],
axis
=
0
)
axis
=
0
)
for
_
in
range
(
len
(
cls_targets
.
get_shape
()[
1
:])):
cls_weights
=
tf
.
expand_dims
(
cls_weights
,
-
1
)
cls_weights
=
tf
.
expand_dims
(
cls_weights
,
-
1
)
cls_weights
=
tf
.
tile
(
cls_weights
,
weights_multiple
)
cls_weights
=
tf
.
tile
(
cls_weights
,
weights_multiple
)
num_box_preds
=
box_preds
.
num_boxes_static
()
if
num_box_preds
is
not
None
:
reg_targets
=
self
.
_reset_target_shape
(
reg_targets
,
num_box_preds
)
cls_targets
=
self
.
_reset_target_shape
(
cls_targets
,
num_box_preds
)
reg_weights
=
self
.
_reset_target_shape
(
reg_weights
,
num_box_preds
)
cls_weights
=
self
.
_reset_target_shape
(
cls_weights
,
num_box_preds
)
return
(
cls_targets
,
cls_weights
,
reg_targets
,
reg_weights
)
return
(
cls_targets
,
cls_weights
,
reg_targets
,
reg_weights
)
def
_reset_target_shape
(
self
,
target
,
num_box_preds
):
"""Sets the static shape of the target.
Args:
target: the target tensor. Its first dimension will be overwritten.
num_box_preds: the number of box_preds, which is used to override the target's
first dimension.
Returns:
A tensor with the shape info filled in.
"""
target_shape
=
target
.
get_shape
().
as_list
()
target_shape
[
0
]
=
num_box_preds
target
.
set_shape
(
target_shape
)
return
target
def
_create_regression_targets
(
self
,
box_preds
,
groundtruth_boxes
,
match
):
"""Returns a regression target for each box_pred.
Args:
box_preds: a BoxList representing N box_preds
groundtruth_boxes: a BoxList representing M groundtruth_boxes
match: a matcher.Match object
Returns:
reg_targets: a float32 tensor with shape [N, box_code_dimension]
"""
matched_gt_boxes
=
match
.
gather_based_on_match
(
groundtruth_boxes
.
get
(),
unmatched_value
=
tf
.
zeros
(
4
),
ignored_value
=
tf
.
zeros
(
4
))
matched_gt_boxlist
=
box_list
.
BoxList
(
matched_gt_boxes
)
ty
,
tx
,
th
,
tw
=
matched_gt_boxlist
.
get_center_coordinates_and_sizes
()
matched_reg_targets
=
tf
.
transpose
(
tf
.
stack
([
ty
,
tx
,
th
,
tw
]))
match_results_shape
=
shape_utils
.
combined_static_and_dynamic_shape
(
match
.
match_results
)
# Zero out the unmatched and ignored regression targets.
unmatched_ignored_reg_targets
=
tf
.
tile
(
tf
.
constant
([
4
*
[
0
]],
tf
.
float32
),
[
match_results_shape
[
0
],
1
])
matched_box_preds_mask
=
match
.
matched_column_indicator
()
reg_targets
=
tf
.
where
(
matched_box_preds_mask
,
matched_reg_targets
,
unmatched_ignored_reg_targets
)
return
reg_targets
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment