Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
4f135c70
Commit
4f135c70
authored
Aug 11, 2020
by
Kaushik Shivakumar
Browse files
compress target assigner
parent
3d757d50
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
96 deletions
+12
-96
research/object_detection/core/target_assigner.py
research/object_detection/core/target_assigner.py
+12
-96
No files found.
research/object_detection/core/target_assigner.py
View file @
4f135c70
...
@@ -2050,13 +2050,18 @@ class DETRTargetAssigner(object):
...
@@ -2050,13 +2050,18 @@ class DETRTargetAssigner(object):
reg_targets
=
self
.
_create_regression_targets
(
anchors
,
reg_targets
=
self
.
_create_regression_targets
(
anchors
,
groundtruth_boxes
,
groundtruth_boxes
,
match
)
match
)
cls_targets
=
self
.
_create_classification_targets
(
groundtruth_labels
,
cls_targets
=
match
.
gather_based_on_match
(
unmatched_class_label
,
groundtruth_labels
,
match
)
unmatched_value
=
unmatched_class_label
,
reg_weights
=
self
.
_create_regression_weights
(
match
,
groundtruth_weights
)
ignored_value
=
unmatched_class_label
)
reg_weights
=
match
.
gather_based_on_match
(
groundtruth_weights
,
ignored_value
=
0.
,
unmatched_value
=
0.
)
cls_weights
=
match
.
gather_based_on_match
(
groundtruth_weights
,
ignored_value
=
0.
,
unmatched_value
=
self
.
_negative_class_weight
)
cls_weights
=
self
.
_create_classification_weights
(
match
,
groundtruth_weights
)
# convert cls_weights from per-anchor to per-class.
# convert cls_weights from per-anchor to per-class.
class_label_shape
=
tf
.
shape
(
cls_targets
)[
1
:]
class_label_shape
=
tf
.
shape
(
cls_targets
)[
1
:]
weights_shape
=
tf
.
shape
(
cls_weights
)
weights_shape
=
tf
.
shape
(
cls_weights
)
...
@@ -2117,98 +2122,9 @@ class DETRTargetAssigner(object):
...
@@ -2117,98 +2122,9 @@ class DETRTargetAssigner(object):
# Zero out the unmatched and ignored regression targets.
# Zero out the unmatched and ignored regression targets.
unmatched_ignored_reg_targets
=
tf
.
tile
(
unmatched_ignored_reg_targets
=
tf
.
tile
(
self
.
_default_regression_target
(
),
[
match_results_shape
[
0
],
1
])
tf
.
constant
([
4
*
[
0
]],
tf
.
float32
),
[
match_results_shape
[
0
],
1
])
matched_anchors_mask
=
match
.
matched_column_indicator
()
matched_anchors_mask
=
match
.
matched_column_indicator
()
reg_targets
=
tf
.
where
(
matched_anchors_mask
,
reg_targets
=
tf
.
where
(
matched_anchors_mask
,
matched_reg_targets
,
matched_reg_targets
,
unmatched_ignored_reg_targets
)
unmatched_ignored_reg_targets
)
return
reg_targets
return
reg_targets
def
_default_regression_target
(
self
):
"""Returns the default target for anchors to regress to.
Default regression targets are set to zero (though in
this implementation what these targets are set to should
not matter as the regression weight of any box set to
regress to the default target is zero).
Returns:
default_target: a float32 tensor with shape [1, box_code_dimension]
"""
return
tf
.
constant
([
4
*
[
0
]],
tf
.
float32
)
def
_create_classification_targets
(
self
,
groundtruth_labels
,
unmatched_class_label
,
match
):
"""Create classification targets for each anchor.
Assign a classification target of for each anchor to the matching
groundtruth label that is provided by match. Anchors that are not matched
to anything are given the target self._unmatched_cls_target
Args:
groundtruth_labels: a tensor of shape [num_gt_boxes, d_1, ... d_k]
with labels for each of the ground_truth boxes. The subshape
[d_1, ... d_k] can be empty (corresponding to scalar labels).
unmatched_class_label: a float32 tensor with shape [d_1, d_2, ..., d_k]
which is consistent with the classification target for each
anchor (and can be empty for scalar targets). This shape must thus be
compatible with the groundtruth labels that are passed to the "assign"
function (which have shape [num_gt_boxes, d_1, d_2, ..., d_k]).
match: a matcher.Match object that provides a matching between anchors
and groundtruth boxes.
Returns:
a float32 tensor with shape [num_anchors, d_1, d_2 ... d_k], where the
subshape [d_1, ..., d_k] is compatible with groundtruth_labels which has
shape [num_gt_boxes, d_1, d_2, ... d_k].
"""
return
match
.
gather_based_on_match
(
groundtruth_labels
,
unmatched_value
=
unmatched_class_label
,
ignored_value
=
unmatched_class_label
)
def
_create_regression_weights
(
self
,
match
,
groundtruth_weights
):
"""Set regression weight for each anchor.
Only positive anchors are set to contribute to the regression loss, so this
method returns a weight of 1 for every positive anchor and 0 for every
negative anchor.
Args:
match: a matcher.Match object that provides a matching between anchors
and groundtruth boxes.
groundtruth_weights: a float tensor of shape [M] indicating the weight to
assign to all anchors match to a particular groundtruth box.
Returns:
a float32 tensor with shape [num_anchors] representing regression weights.
"""
return
match
.
gather_based_on_match
(
groundtruth_weights
,
ignored_value
=
0.
,
unmatched_value
=
0.
)
def
_create_classification_weights
(
self
,
match
,
groundtruth_weights
):
"""Create classification weights for each anchor.
Positive (matched) anchors are associated with a weight of
positive_class_weight and negative (unmatched) anchors are associated with
a weight of negative_class_weight. When anchors are ignored, weights are set
to zero. By default, both positive/negative weights are set to 1.0,
but they can be adjusted to handle class imbalance (which is almost always
the case in object detection).
Args:
match: a matcher.Match object that provides a matching between anchors
and groundtruth boxes.
groundtruth_weights: a float tensor of shape [M] indicating the weight to
assign to all anchors match to a particular groundtruth box.
Returns:
a float32 tensor with shape [num_anchors] representing classification
weights.
"""
return
match
.
gather_based_on_match
(
groundtruth_weights
,
ignored_value
=
0.
,
unmatched_value
=
self
.
_negative_class_weight
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment