Commit 9d4b102c authored by Kaushik Shivakumar's avatar Kaushik Shivakumar
Browse files

clean target assigner

parent 656ec2a6
...@@ -1953,12 +1953,7 @@ class DETRTargetAssigner(object): ...@@ -1953,12 +1953,7 @@ class DETRTargetAssigner(object):
num_classes], num_classes],
batch_reg_targets: a tensor with shape [batch_size, num_pred_boxes, batch_reg_targets: a tensor with shape [batch_size, num_pred_boxes,
box_code_dimension] box_code_dimension]
batch_reg_weights: a tensor with shape [batch_size, num_pred_boxes], batch_reg_weights: a tensor with shape [batch_size, num_pred_boxes].
match: an int32 tensor of shape [batch_size, num_pred_boxes] containing
result of predicted box groundtruth matching. Each position in the
tensor indicates an predicted box and holds the following meaning:
(1) if match[x, i] >= 0, predicted box i is matched with groundtruth match[x, i].
(2) if match[x, i] = -1, predicted box i is marked to be background.
""" """
cls_targets_list = [] cls_targets_list = []
cls_weights_list = [] cls_weights_list = []
...@@ -1989,7 +1984,7 @@ class DETRTargetAssigner(object): ...@@ -1989,7 +1984,7 @@ class DETRTargetAssigner(object):
groundtruth_boxes, groundtruth_boxes,
class_predictions, class_predictions,
groundtruth_labels, groundtruth_labels,
groundtruth_weights): groundtruth_weights=None):
"""Assign classification and regression targets to each box_pred. """Assign classification and regression targets to each box_pred.
For a given set of box_preds and groundtruth detections, match box_preds For a given set of box_preds and groundtruth detections, match box_preds
...@@ -2028,12 +2023,7 @@ class DETRTargetAssigner(object): ...@@ -2028,12 +2023,7 @@ class DETRTargetAssigner(object):
reg_weights: a float32 tensor with shape [num_box_preds] reg_weights: a float32 tensor with shape [num_box_preds]
""" """
unmatched_class_label = tf.constant([1] + [0] * groundtruth_labels.shape[1], tf.float32) unmatched_class_label = tf.constant([1] + [0] * (groundtruth_labels.shape[1] - 1), tf.float32)
if groundtruth_labels is None:
groundtruth_labels = tf.ones(tf.expand_dims(groundtruth_boxes.num_boxes(),
0))
groundtruth_labels = tf.expand_dims(groundtruth_labels, -1)
if groundtruth_weights is None: if groundtruth_weights is None:
num_gt_boxes = groundtruth_boxes.num_boxes_static() num_gt_boxes = groundtruth_boxes.num_boxes_static()
...@@ -2041,10 +2031,6 @@ class DETRTargetAssigner(object): ...@@ -2041,10 +2031,6 @@ class DETRTargetAssigner(object):
num_gt_boxes = groundtruth_boxes.num_boxes() num_gt_boxes = groundtruth_boxes.num_boxes()
groundtruth_weights = tf.ones([num_gt_boxes], dtype=tf.float32) groundtruth_weights = tf.ones([num_gt_boxes], dtype=tf.float32)
# set scores on the gt boxes
scores = 1 - groundtruth_labels[:, 0]
groundtruth_boxes.add_field(fields.BoxListFields.scores, scores)
groundtruth_boxes.add_field(fields.BoxListFields.classes, groundtruth_labels) groundtruth_boxes.add_field(fields.BoxListFields.classes, groundtruth_labels)
box_preds.add_field(fields.BoxListFields.classes, class_predictions) box_preds.add_field(fields.BoxListFields.classes, class_predictions)
...@@ -2054,10 +2040,13 @@ class DETRTargetAssigner(object): ...@@ -2054,10 +2040,13 @@ class DETRTargetAssigner(object):
match = self._matcher.match(match_quality_matrix, match = self._matcher.match(match_quality_matrix,
valid_rows=tf.greater(groundtruth_weights, 0)) valid_rows=tf.greater(groundtruth_weights, 0))
reg_targets = self._create_regression_targets( matched_gt_boxes = match.gather_based_on_match(
box_preds, groundtruth_boxes.get(),
groundtruth_boxes, unmatched_value=tf.zeros(4),
match) ignored_value=tf.zeros(4))
matched_gt_boxlist = box_list.BoxList(matched_gt_boxes)
ty, tx, th, tw = matched_gt_boxlist.get_center_coordinates_and_sizes()
reg_targets = tf.transpose(tf.stack([ty, tx, th, tw]))
cls_targets = match.gather_based_on_match( cls_targets = match.gather_based_on_match(
groundtruth_labels, groundtruth_labels,
unmatched_value=unmatched_class_label, unmatched_value=unmatched_class_label,
...@@ -2073,66 +2062,10 @@ class DETRTargetAssigner(object): ...@@ -2073,66 +2062,10 @@ class DETRTargetAssigner(object):
# convert cls_weights from per-box_pred to per-class. # convert cls_weights from per-box_pred to per-class.
class_label_shape = tf.shape(cls_targets)[1:] class_label_shape = tf.shape(cls_targets)[1:]
weights_shape = tf.shape(cls_weights)
weights_multiple = tf.concat( weights_multiple = tf.concat(
[tf.ones_like(weights_shape), class_label_shape], [tf.constant([1]), class_label_shape],
axis=0) axis=0)
for _ in range(len(cls_targets.get_shape()[1:])):
cls_weights = tf.expand_dims(cls_weights, -1) cls_weights = tf.expand_dims(cls_weights, -1)
cls_weights = tf.tile(cls_weights, weights_multiple) cls_weights = tf.tile(cls_weights, weights_multiple)
num_box_preds = box_preds.num_boxes_static()
if num_box_preds is not None:
reg_targets = self._reset_target_shape(reg_targets, num_box_preds)
cls_targets = self._reset_target_shape(cls_targets, num_box_preds)
reg_weights = self._reset_target_shape(reg_weights, num_box_preds)
cls_weights = self._reset_target_shape(cls_weights, num_box_preds)
return (cls_targets, cls_weights, reg_targets, reg_weights) return (cls_targets, cls_weights, reg_targets, reg_weights)
def _reset_target_shape(self, target, num_box_preds):
"""Sets the static shape of the target.
Args:
target: the target tensor. Its first dimension will be overwritten.
num_box_preds: the number of box_preds, which is used to override the target's
first dimension.
Returns:
A tensor with the shape info filled in.
"""
target_shape = target.get_shape().as_list()
target_shape[0] = num_box_preds
target.set_shape(target_shape)
return target
def _create_regression_targets(self, box_preds, groundtruth_boxes, match):
"""Returns a regression target for each box_pred.
Args:
box_preds: a BoxList representing N box_preds
groundtruth_boxes: a BoxList representing M groundtruth_boxes
match: a matcher.Match object
Returns:
reg_targets: a float32 tensor with shape [N, box_code_dimension]
"""
matched_gt_boxes = match.gather_based_on_match(
groundtruth_boxes.get(),
unmatched_value=tf.zeros(4),
ignored_value=tf.zeros(4))
matched_gt_boxlist = box_list.BoxList(matched_gt_boxes)
ty, tx, th, tw = matched_gt_boxlist.get_center_coordinates_and_sizes()
matched_reg_targets = tf.transpose(tf.stack([ty, tx, th, tw]))
match_results_shape = shape_utils.combined_static_and_dynamic_shape(
match.match_results)
# Zero out the unmatched and ignored regression targets.
unmatched_ignored_reg_targets = tf.tile(
tf.constant([4 * [0]], tf.float32), [match_results_shape[0], 1])
matched_box_preds_mask = match.matched_column_indicator()
reg_targets = tf.where(matched_box_preds_mask,
matched_reg_targets,
unmatched_ignored_reg_targets)
return reg_targets
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment