Commit 4f135c70 authored by Kaushik Shivakumar's avatar Kaushik Shivakumar
Browse files

compress target assigner

parent 3d757d50
...@@ -2050,13 +2050,18 @@ class DETRTargetAssigner(object): ...@@ -2050,13 +2050,18 @@ class DETRTargetAssigner(object):
reg_targets = self._create_regression_targets(anchors, reg_targets = self._create_regression_targets(anchors,
groundtruth_boxes, groundtruth_boxes,
match) match)
cls_targets = self._create_classification_targets(groundtruth_labels, cls_targets = match.gather_based_on_match(
unmatched_class_label, groundtruth_labels,
match) unmatched_value=unmatched_class_label,
reg_weights = self._create_regression_weights(match, groundtruth_weights) ignored_value=unmatched_class_label)
reg_weights = match.gather_based_on_match(groundtruth_weights,
ignored_value=0.,
unmatched_value=0.)
cls_weights = match.gather_based_on_match(
groundtruth_weights,
ignored_value=0.,
unmatched_value=self._negative_class_weight)
cls_weights = self._create_classification_weights(match,
groundtruth_weights)
# convert cls_weights from per-anchor to per-class. # convert cls_weights from per-anchor to per-class.
class_label_shape = tf.shape(cls_targets)[1:] class_label_shape = tf.shape(cls_targets)[1:]
weights_shape = tf.shape(cls_weights) weights_shape = tf.shape(cls_weights)
...@@ -2117,98 +2122,9 @@ class DETRTargetAssigner(object): ...@@ -2117,98 +2122,9 @@ class DETRTargetAssigner(object):
# Zero out the unmatched and ignored regression targets. # Zero out the unmatched and ignored regression targets.
unmatched_ignored_reg_targets = tf.tile( unmatched_ignored_reg_targets = tf.tile(
self._default_regression_target(), [match_results_shape[0], 1]) tf.constant([4 * [0]], tf.float32), [match_results_shape[0], 1])
matched_anchors_mask = match.matched_column_indicator() matched_anchors_mask = match.matched_column_indicator()
reg_targets = tf.where(matched_anchors_mask, reg_targets = tf.where(matched_anchors_mask,
matched_reg_targets, matched_reg_targets,
unmatched_ignored_reg_targets) unmatched_ignored_reg_targets)
return reg_targets return reg_targets
def _default_regression_target(self):
"""Returns the default target for anchors to regress to.
Default regression targets are set to zero (though in
this implementation what these targets are set to should
not matter as the regression weight of any box set to
regress to the default target is zero).
Returns:
default_target: a float32 tensor with shape [1, box_code_dimension]
"""
return tf.constant([4 * [0]], tf.float32)
def _create_classification_targets(self, groundtruth_labels,
unmatched_class_label, match):
"""Create classification targets for each anchor.
Assign a classification target of for each anchor to the matching
groundtruth label that is provided by match. Anchors that are not matched
to anything are given the target self._unmatched_cls_target
Args:
groundtruth_labels: a tensor of shape [num_gt_boxes, d_1, ... d_k]
with labels for each of the ground_truth boxes. The subshape
[d_1, ... d_k] can be empty (corresponding to scalar labels).
unmatched_class_label: a float32 tensor with shape [d_1, d_2, ..., d_k]
which is consistent with the classification target for each
anchor (and can be empty for scalar targets). This shape must thus be
compatible with the groundtruth labels that are passed to the "assign"
function (which have shape [num_gt_boxes, d_1, d_2, ..., d_k]).
match: a matcher.Match object that provides a matching between anchors
and groundtruth boxes.
Returns:
a float32 tensor with shape [num_anchors, d_1, d_2 ... d_k], where the
subshape [d_1, ..., d_k] is compatible with groundtruth_labels which has
shape [num_gt_boxes, d_1, d_2, ... d_k].
"""
return match.gather_based_on_match(
groundtruth_labels,
unmatched_value=unmatched_class_label,
ignored_value=unmatched_class_label)
def _create_regression_weights(self, match, groundtruth_weights):
"""Set regression weight for each anchor.
Only positive anchors are set to contribute to the regression loss, so this
method returns a weight of 1 for every positive anchor and 0 for every
negative anchor.
Args:
match: a matcher.Match object that provides a matching between anchors
and groundtruth boxes.
groundtruth_weights: a float tensor of shape [M] indicating the weight to
assign to all anchors match to a particular groundtruth box.
Returns:
a float32 tensor with shape [num_anchors] representing regression weights.
"""
return match.gather_based_on_match(
groundtruth_weights, ignored_value=0., unmatched_value=0.)
def _create_classification_weights(self,
match,
groundtruth_weights):
"""Create classification weights for each anchor.
Positive (matched) anchors are associated with a weight of
positive_class_weight and negative (unmatched) anchors are associated with
a weight of negative_class_weight. When anchors are ignored, weights are set
to zero. By default, both positive/negative weights are set to 1.0,
but they can be adjusted to handle class imbalance (which is almost always
the case in object detection).
Args:
match: a matcher.Match object that provides a matching between anchors
and groundtruth boxes.
groundtruth_weights: a float tensor of shape [M] indicating the weight to
assign to all anchors match to a particular groundtruth box.
Returns:
a float32 tensor with shape [num_anchors] representing classification
weights.
"""
return match.gather_based_on_match(
groundtruth_weights,
ignored_value=0.,
unmatched_value=self._negative_class_weight)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment