Commit 8f5ed2de authored by Kaushik Shivakumar's avatar Kaushik Shivakumar
Browse files

fixes to simplify

parent e09e0566
...@@ -108,8 +108,7 @@ class TargetAssigner(object): ...@@ -108,8 +108,7 @@ class TargetAssigner(object):
groundtruth_boxes, groundtruth_boxes,
groundtruth_labels=None, groundtruth_labels=None,
unmatched_class_label=None, unmatched_class_label=None,
groundtruth_weights=None, groundtruth_weights=None):
class_predictions=None):
"""Assign classification and regression targets to each anchor. """Assign classification and regression targets to each anchor.
For a given set of anchors and groundtruth detections, match anchors For a given set of anchors and groundtruth detections, match anchors
...@@ -451,8 +450,7 @@ def batch_assign(target_assigner, ...@@ -451,8 +450,7 @@ def batch_assign(target_assigner,
gt_box_batch, gt_box_batch,
gt_class_targets_batch, gt_class_targets_batch,
unmatched_class_label=None, unmatched_class_label=None,
gt_weights_batch=None, gt_weights_batch=None):
class_predictions=None):
"""Batched assignment of classification and regression targets. """Batched assignment of classification and regression targets.
Args: Args:
...@@ -472,8 +470,6 @@ def batch_assign(target_assigner, ...@@ -472,8 +470,6 @@ def batch_assign(target_assigner,
function (which have shape [num_gt_boxes, d_1, d_2, ..., d_k]). function (which have shape [num_gt_boxes, d_1, d_2, ..., d_k]).
gt_weights_batch: A list of 1-D tf.float32 tensors of shape gt_weights_batch: A list of 1-D tf.float32 tensors of shape
[num_boxes] containing weights for groundtruth boxes. [num_boxes] containing weights for groundtruth boxes.
class_predictions: A tensor with shape [batch_size, max_num_boxes,
d_1, d_2, ..., d_k] to be used by certain similarity calculators.
Returns: Returns:
batch_cls_targets: a tensor with shape [batch_size, num_anchors, batch_cls_targets: a tensor with shape [batch_size, num_anchors,
...@@ -514,17 +510,12 @@ def batch_assign(target_assigner, ...@@ -514,17 +510,12 @@ def batch_assign(target_assigner,
match_list = [] match_list = []
if gt_weights_batch is None: if gt_weights_batch is None:
gt_weights_batch = [None] * len(gt_class_targets_batch) gt_weights_batch = [None] * len(gt_class_targets_batch)
if class_predictions: for anchors, gt_boxes, gt_class_targets, gt_weights in zip(
class_predictions = tf.unstack(class_predictions) anchors_batch, gt_box_batch, gt_class_targets_batch, gt_weights_batch):
else:
class_predictions = [None] * len(gt_class_targets_batch)
for anchors, gt_boxes, gt_class_targets, gt_weights, class_preds in zip(
anchors_batch, gt_box_batch, gt_class_targets_batch, gt_weights_batch,
class_predictions):
(cls_targets, cls_weights, (cls_targets, cls_weights,
reg_targets, reg_weights, match) = target_assigner.assign( reg_targets, reg_weights, match) = target_assigner.assign(
anchors, gt_boxes, gt_class_targets, unmatched_class_label, anchors, gt_boxes, gt_class_targets, unmatched_class_label,
gt_weights, class_preds) gt_weights)
cls_targets_list.append(cls_targets) cls_targets_list.append(cls_targets)
cls_weights_list.append(cls_weights) cls_weights_list.append(cls_weights)
reg_targets_list.append(reg_targets) reg_targets_list.append(reg_targets)
...@@ -1931,13 +1922,74 @@ class DETRTargetAssigner(object): ...@@ -1931,13 +1922,74 @@ class DETRTargetAssigner(object):
self._matcher = hungarian_matcher.HungarianBipartiteMatcher() self._matcher = hungarian_matcher.HungarianBipartiteMatcher()
self._negative_class_weight = negative_class_weight self._negative_class_weight = negative_class_weight
def batch_assign(self,
pred_boxes_batch,
gt_box_batch,
class_predictions,
gt_class_targets_batch,
gt_weights_batch=None):
"""Batched assignment of classification and regression targets.
Args:
pred_boxes_batch: list of BoxList objects with length batch_size
representing predicted box sets.
gt_box_batch: a list of BoxList objects with length batch_size
representing groundtruth boxes for each image in the batch
class_predictions: A list of tensors with length batch_size, where each
each tensor has shape [max_num_boxes, num_classes] to be used
by certain similarity calculators.
gt_class_targets_batch: a list of tensors with length batch_size, where
each tensor has shape [num_gt_boxes_i, num_classes] and
num_gt_boxes_i is the number of boxes in the ith boxlist of
gt_box_batch.
gt_weights_batch: A list of 1-D tf.float32 tensors of shape
[num_boxes] containing weights for groundtruth boxes.
Returns:
batch_cls_targets: a tensor with shape [batch_size, num_pred_boxes,
num_classes],
batch_cls_weights: a tensor with shape [batch_size, num_pred_boxes,
num_classes],
batch_reg_targets: a tensor with shape [batch_size, num_pred_boxes,
box_code_dimension]
batch_reg_weights: a tensor with shape [batch_size, num_pred_boxes],
match: an int32 tensor of shape [batch_size, num_pred_boxes] containing
result of predicted box groundtruth matching. Each position in the
tensor indicates an predicted box and holds the following meaning:
(1) if match[x, i] >= 0, predicted box i is matched with groundtruth match[x, i].
(2) if match[x, i] = -1, predicted box i is marked to be background.
"""
cls_targets_list = []
cls_weights_list = []
reg_targets_list = []
reg_weights_list = []
if gt_weights_batch is None:
gt_weights_batch = [None] * len(gt_class_targets_batch)
class_predictions = tf.unstack(class_predictions)
for pred_boxes, gt_boxes, class_preds, gt_class_targets, gt_weights in zip(
pred_boxes_batch, gt_box_batch, class_predictions,
gt_class_targets_batch, gt_weights_batch):
(cls_targets, cls_weights,
reg_targets, reg_weights) = self.assign(
pred_boxes, gt_boxes, class_preds, gt_class_targets, gt_weights)
cls_targets_list.append(cls_targets)
cls_weights_list.append(cls_weights)
reg_targets_list.append(reg_targets)
reg_weights_list.append(reg_weights)
batch_cls_targets = tf.stack(cls_targets_list)
batch_cls_weights = tf.stack(cls_weights_list)
batch_reg_targets = tf.stack(reg_targets_list)
batch_reg_weights = tf.stack(reg_weights_list)
return (batch_cls_targets, batch_cls_weights, batch_reg_targets,
batch_reg_weights)
def assign(self, def assign(self,
box_preds, box_preds,
groundtruth_boxes, groundtruth_boxes,
groundtruth_labels=None, class_predictions,
unmatched_class_label=None, groundtruth_labels,
groundtruth_weights=None, groundtruth_weights):
class_predictions=None):
"""Assign classification and regression targets to each box_pred. """Assign classification and regression targets to each box_pred.
For a given set of box_preds and groundtruth detections, match box_preds For a given set of box_preds and groundtruth detections, match box_preds
...@@ -1951,17 +2003,13 @@ class DETRTargetAssigner(object): ...@@ -1951,17 +2003,13 @@ class DETRTargetAssigner(object):
Args: Args:
box_preds: a BoxList representing N box_preds box_preds: a BoxList representing N box_preds
groundtruth_boxes: a BoxList representing M groundtruth boxes groundtruth_boxes: a BoxList representing M groundtruth boxes
groundtruth_labels: a tensor of shape [M, d_1, ... d_k] class_predictions: A tensor with shape [max_num_boxes, num_classes]
to be used by certain similarity calculators.
groundtruth_labels: a tensor of shape [M, num_classes]
with labels for each of the ground_truth boxes. The subshape with labels for each of the ground_truth boxes. The subshape
[d_1, ... d_k] can be empty (corresponding to scalar inputs). When set [num_classes] can be empty (corresponding to scalar inputs). When set
to None, groundtruth_labels assumes a binary problem where all to None, groundtruth_labels assumes a binary problem where all
ground_truth boxes get a positive label (of 1). ground_truth boxes get a positive label (of 1).
unmatched_class_label: a float32 tensor with shape [d_1, d_2, ..., d_k]
which is consistent with the classification target for each
box_pred (and can be empty for scalar targets). This shape must thus be
compatible with the groundtruth labels that are passed to the "assign"
function (which have shape [num_gt_boxes, d_1, d_2, ..., d_k]).
If set to None, unmatched_cls_target is set to be [0] for each box_pred.
groundtruth_weights: a float tensor of shape [M] indicating the weight to groundtruth_weights: a float tensor of shape [M] indicating the weight to
assign to all box_preds match to a particular groundtruth box. The weights assign to all box_preds match to a particular groundtruth box. The weights
must be in [0., 1.]. If None, all weights are set to 1. Generally no must be in [0., 1.]. If None, all weights are set to 1. Generally no
...@@ -1969,14 +2017,12 @@ class DETRTargetAssigner(object): ...@@ -1969,14 +2017,12 @@ class DETRTargetAssigner(object):
aware of groundtruth weights. Additionally, `cls_weights` and aware of groundtruth weights. Additionally, `cls_weights` and
`reg_weights` are calculated using groundtruth weights as an added `reg_weights` are calculated using groundtruth weights as an added
safety. safety.
class_predictions: A tensor with shape [max_num_boxes, d_1, d_2, ..., d_k]
to be used by certain similarity calculators.
Returns: Returns:
cls_targets: a float32 tensor with shape [num_box_preds, d_1, d_2 ... d_k], cls_targets: a float32 tensor with shape [num_box_preds, num_classes],
where the subshape [d_1, ..., d_k] is compatible with groundtruth_labels where the subshape [num_classes] is compatible with groundtruth_labels
which has shape [num_gt_boxes, d_1, d_2, ... d_k]. which has shape [num_gt_boxes, num_classes].
cls_weights: a float32 tensor with shape [num_box_preds, d_1, d_2 ... d_k], cls_weights: a float32 tensor with shape [num_box_preds, num_classes],
representing weights for each element in cls_targets. representing weights for each element in cls_targets.
reg_targets: a float32 tensor with shape [num_box_preds, box_code_dimension] reg_targets: a float32 tensor with shape [num_box_preds, box_code_dimension]
reg_weights: a float32 tensor with shape [num_box_preds] reg_weights: a float32 tensor with shape [num_box_preds]
...@@ -1988,17 +2034,8 @@ class DETRTargetAssigner(object): ...@@ -1988,17 +2034,8 @@ class DETRTargetAssigner(object):
(3) if match[i]=-2, box_pred i is ignored since it is not background and (3) if match[i]=-2, box_pred i is ignored since it is not background and
does not have sufficient overlap to call it a foreground. does not have sufficient overlap to call it a foreground.
Raises:
ValueError: if box_preds or groundtruth_boxes are not of type
box_list.BoxList
""" """
if not isinstance(box_preds, box_list.BoxList): unmatched_class_label = tf.constant([1] + [0] * groundtruth_labels.shape[1], tf.float32)
raise ValueError('box_preds must be an BoxList')
if not isinstance(groundtruth_boxes, box_list.BoxList):
raise ValueError('groundtruth_boxes must be an BoxList')
if unmatched_class_label is None:
unmatched_class_label = tf.constant([0], tf.float32)
if groundtruth_labels is None: if groundtruth_labels is None:
groundtruth_labels = tf.ones(tf.expand_dims(groundtruth_boxes.num_boxes(), groundtruth_labels = tf.ones(tf.expand_dims(groundtruth_boxes.num_boxes(),
...@@ -2024,16 +2061,18 @@ class DETRTargetAssigner(object): ...@@ -2024,16 +2061,18 @@ class DETRTargetAssigner(object):
match = self._matcher.match(match_quality_matrix, match = self._matcher.match(match_quality_matrix,
valid_rows=tf.greater(groundtruth_weights, 0)) valid_rows=tf.greater(groundtruth_weights, 0))
reg_targets = self._create_regression_targets(box_preds, reg_targets = self._create_regression_targets(
groundtruth_boxes, box_preds,
match) groundtruth_boxes,
match)
cls_targets = match.gather_based_on_match( cls_targets = match.gather_based_on_match(
groundtruth_labels, groundtruth_labels,
unmatched_value=unmatched_class_label, unmatched_value=unmatched_class_label,
ignored_value=unmatched_class_label) ignored_value=unmatched_class_label)
reg_weights = match.gather_based_on_match(groundtruth_weights, reg_weights = match.gather_based_on_match(
ignored_value=0., groundtruth_weights,
unmatched_value=0.) ignored_value=0.,
unmatched_value=0.)
cls_weights = match.gather_based_on_match( cls_weights = match.gather_based_on_match(
groundtruth_weights, groundtruth_weights,
ignored_value=0., ignored_value=0.,
...@@ -2056,8 +2095,7 @@ class DETRTargetAssigner(object): ...@@ -2056,8 +2095,7 @@ class DETRTargetAssigner(object):
reg_weights = self._reset_target_shape(reg_weights, num_box_preds) reg_weights = self._reset_target_shape(reg_weights, num_box_preds)
cls_weights = self._reset_target_shape(cls_weights, num_box_preds) cls_weights = self._reset_target_shape(cls_weights, num_box_preds)
return (cls_targets, cls_weights, reg_targets, reg_weights, return (cls_targets, cls_weights, reg_targets, reg_weights)
match.match_results)
def _reset_target_shape(self, target, num_box_preds): def _reset_target_shape(self, target, num_box_preds):
"""Sets the static shape of the target. """Sets the static shape of the target.
......
...@@ -504,7 +504,6 @@ class BatchTargetAssignerTest(test_case.TestCase): ...@@ -504,7 +504,6 @@ class BatchTargetAssignerTest(test_case.TestCase):
return targetassigner.TargetAssigner(similarity_calc, matcher, box_coder) return targetassigner.TargetAssigner(similarity_calc, matcher, box_coder)
def test_batch_assign_targets(self): def test_batch_assign_targets(self):
def graph_fn(anchor_means, groundtruth_boxlist1, groundtruth_boxlist2): def graph_fn(anchor_means, groundtruth_boxlist1, groundtruth_boxlist2):
box_list1 = box_list.BoxList(groundtruth_boxlist1) box_list1 = box_list.BoxList(groundtruth_boxlist1)
box_list2 = box_list.BoxList(groundtruth_boxlist2) box_list2 = box_list.BoxList(groundtruth_boxlist2)
...@@ -2193,21 +2192,18 @@ class CornerOffsetTargetAssignerTest(test_case.TestCase): ...@@ -2193,21 +2192,18 @@ class CornerOffsetTargetAssignerTest(test_case.TestCase):
class DETRTargetAssignerTest(test_case.TestCase): class DETRTargetAssignerTest(test_case.TestCase):
def test_assign_detr(self): def test_assign_detr(self):
def graph_fn(anchor_means, groundtruth_box_corners, def graph_fn(pred_corners, groundtruth_box_corners,
groundtruth_labels, predicted_labels): groundtruth_labels, predicted_labels):
detr_target_assigner = targetassigner.DETRTargetAssigner() detr_target_assigner = targetassigner.DETRTargetAssigner()
anchors_boxlist = box_list.BoxList(anchor_means) pred_boxlist = box_list.BoxList(pred_corners)
groundtruth_boxlist = box_list.BoxList(groundtruth_box_corners) groundtruth_boxlist = box_list.BoxList(groundtruth_box_corners)
result = detr_target_assigner.assign( result = detr_target_assigner.assign(
anchors_boxlist, groundtruth_boxlist, pred_boxlist, groundtruth_boxlist,
unmatched_class_label=tf.constant( predicted_labels, groundtruth_labels)
[1, 0], dtype=tf.float32), (cls_targets, cls_weights, reg_targets, reg_weights) = result
groundtruth_labels=groundtruth_labels,
class_predictions=predicted_labels)
(cls_targets, cls_weights, reg_targets, reg_weights, _) = result
return (cls_targets, cls_weights, reg_targets, reg_weights) return (cls_targets, cls_weights, reg_targets, reg_weights)
anchor_means = np.array([[0.25, 0.25, 0.4, 0.2], pred_corners = np.array([[0.25, 0.25, 0.4, 0.2],
[0.5, 0.8, 1.0, 0.8], [0.5, 0.8, 1.0, 0.8],
[0.9, 0.5, 0.1, 1.0]], dtype=np.float32) [0.9, 0.5, 0.1, 1.0]], dtype=np.float32)
groundtruth_box_corners = np.array([[0.0, 0.0, 0.5, 0.5], groundtruth_box_corners = np.array([[0.0, 0.0, 0.5, 0.5],
...@@ -2227,7 +2223,51 @@ class DETRTargetAssignerTest(test_case.TestCase): ...@@ -2227,7 +2223,51 @@ class DETRTargetAssignerTest(test_case.TestCase):
(cls_targets_out, (cls_targets_out,
cls_weights_out, reg_targets_out, reg_weights_out) = self.execute( cls_weights_out, reg_targets_out, reg_weights_out) = self.execute(
graph_fn, [anchor_means, groundtruth_box_corners, graph_fn, [pred_corners, groundtruth_box_corners,
groundtruth_labels, predicted_labels])
self.assertAllClose(cls_targets_out, exp_cls_targets)
self.assertAllClose(cls_weights_out, exp_cls_weights)
self.assertAllClose(reg_targets_out, exp_reg_targets)
self.assertAllClose(reg_weights_out, exp_reg_weights)
self.assertEqual(cls_targets_out.dtype, np.float32)
self.assertEqual(cls_weights_out.dtype, np.float32)
self.assertEqual(reg_targets_out.dtype, np.float32)
self.assertEqual(reg_weights_out.dtype, np.float32)
def test_batch_assign_detr(self):
def graph_fn(pred_corners, groundtruth_box_corners,
groundtruth_labels, predicted_labels):
detr_target_assigner = targetassigner.DETRTargetAssigner()
pred_boxlist = [box_list.BoxList(pred_corners)]
groundtruth_boxlist = [box_list.BoxList(groundtruth_box_corners)]
result = detr_target_assigner.batch_assign(
pred_boxlist, groundtruth_boxlist,
[predicted_labels], [groundtruth_labels])
(cls_targets, cls_weights, reg_targets, reg_weights) = result
return (cls_targets, cls_weights, reg_targets, reg_weights)
pred_corners = np.array([[0.25, 0.25, 0.4, 0.2],
[0.5, 0.8, 1.0, 0.8],
[0.9, 0.5, 0.1, 1.0]], dtype=np.float32)
groundtruth_box_corners = np.array([[0.0, 0.0, 0.5, 0.5],
[0.5, 0.5, 0.9, 0.9]],
dtype=np.float32)
predicted_labels = np.array([[-3.0, 3.0], [2.0, 9.4], [5.0, 1.0]],
dtype=np.float32)
groundtruth_labels = np.array([[0.0, 1.0], [0.0, 1.0]],
dtype=np.float32)
exp_cls_targets = [[[0, 1], [0, 1], [1, 0]]]
exp_cls_weights = [[[1, 1], [1, 1], [1, 1]]]
exp_reg_targets = [[[0.25, 0.25, 0.5, 0.5],
[0.7, 0.7, 0.4, 0.4],
[0, 0, 0, 0]]]
exp_reg_weights = [[1, 1, 0]]
(cls_targets_out,
cls_weights_out, reg_targets_out, reg_weights_out) = self.execute(
graph_fn, [pred_corners, groundtruth_box_corners,
groundtruth_labels, predicted_labels]) groundtruth_labels, predicted_labels])
self.assertAllClose(cls_targets_out, exp_cls_targets) self.assertAllClose(cls_targets_out, exp_cls_targets)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment