Commit e09e0566 authored by Kaushik Shivakumar's avatar Kaushik Shivakumar
Browse files

target assigner and similarity calculator fixes

parent d54c86de
...@@ -53,7 +53,7 @@ class RegionSimilarityCalculator(six.with_metaclass(ABCMeta, object)): ...@@ -53,7 +53,7 @@ class RegionSimilarityCalculator(six.with_metaclass(ABCMeta, object)):
a (float32) tensor of shape [N, M] with pairwise similarity score. a (float32) tensor of shape [N, M] with pairwise similarity score.
""" """
with tf.name_scope(scope, 'Compare', [boxlist1, boxlist2]) as scope: with tf.name_scope(scope, 'Compare', [boxlist1, boxlist2]) as scope:
return self._compare(boxlist1, boxlist2, groundtruth_labels, predicted_labels) return self._compare(boxlist1, boxlist2)
@abstractmethod @abstractmethod
def _compare(self, boxlist1, boxlist2, def _compare(self, boxlist1, boxlist2,
......
...@@ -103,7 +103,8 @@ class RegionSimilarityCalculatorTest(test_case.TestCase): ...@@ -103,7 +103,8 @@ class RegionSimilarityCalculatorTest(test_case.TestCase):
boxes2 = box_list.BoxList(corners2) boxes2 = box_list.BoxList(corners2)
boxes1.add_field(fields.BoxListFields.classes, groundtruth_labels) boxes1.add_field(fields.BoxListFields.classes, groundtruth_labels)
boxes2.add_field(fields.BoxListFields.classes, predicted_labels) boxes2.add_field(fields.BoxListFields.classes, predicted_labels)
detr_similarity_calculator = region_similarity_calculator.DETRSimilarity() detr_similarity_calculator = \
region_similarity_calculator.DETRSimilarity()
detr_similarity = detr_similarity_calculator.compare( detr_similarity = detr_similarity_calculator.compare(
boxes1, boxes2, None) boxes1, boxes2, None)
return detr_similarity return detr_similarity
......
...@@ -437,9 +437,7 @@ def create_target_assigner(reference, stage=None, ...@@ -437,9 +437,7 @@ def create_target_assigner(reference, stage=None,
box_coder_instance = faster_rcnn_box_coder.FasterRcnnBoxCoder() box_coder_instance = faster_rcnn_box_coder.FasterRcnnBoxCoder()
elif reference == 'DETR': elif reference == 'DETR':
similarity_calc = sim_calc.DETRSimilarity() return DETRTargetAssigner()
matcher = hungarian_matcher.HungarianBipartiteMatcher()
return DETRTargetAssigner(similarity_calc, matcher)
else: else:
raise ValueError('No valid combination of reference and stage.') raise ValueError('No valid combination of reference and stage.')
...@@ -1917,9 +1915,7 @@ class CenterNetCornerOffsetTargetAssigner(object): ...@@ -1917,9 +1915,7 @@ class CenterNetCornerOffsetTargetAssigner(object):
class DETRTargetAssigner(object): class DETRTargetAssigner(object):
"""Target assigner to compute classification and regression targets.""" """Target assigner to compute classification and regression targets."""
def __init__(self, def __init__(self, negative_class_weight=1.0):
matcher,
negative_class_weight=1.0):
"""Construct Object Detection Target Assigner. """Construct Object Detection Target Assigner.
Args: Args:
...@@ -1931,8 +1927,6 @@ class DETRTargetAssigner(object): ...@@ -1931,8 +1927,6 @@ class DETRTargetAssigner(object):
boxes (default: 1.0). The weight must be in [0., 1.]. boxes (default: 1.0). The weight must be in [0., 1.].
""" """
if not isinstance(matcher, mat.Matcher):
raise ValueError('matcher must be a Matcher')
self._similarity_calc = sim_calc.DETRSimilarity() self._similarity_calc = sim_calc.DETRSimilarity()
self._matcher = hungarian_matcher.HungarianBipartiteMatcher() self._matcher = hungarian_matcher.HungarianBipartiteMatcher()
self._negative_class_weight = negative_class_weight self._negative_class_weight = negative_class_weight
...@@ -2024,39 +2018,36 @@ class DETRTargetAssigner(object): ...@@ -2024,39 +2018,36 @@ class DETRTargetAssigner(object):
groundtruth_boxes.add_field(fields.BoxListFields.classes, groundtruth_labels) groundtruth_boxes.add_field(fields.BoxListFields.classes, groundtruth_labels)
box_preds.add_field(fields.BoxListFields.classes, class_predictions) box_preds.add_field(fields.BoxListFields.classes, class_predictions)
with tf.control_dependencies( match_quality_matrix = self._similarity_calc.compare(
[unmatched_shape_assert, labels_and_box_shapes_assert]): groundtruth_boxes,
box_preds)
match_quality_matrix = self._similarity_calc.compare( match = self._matcher.match(match_quality_matrix,
groundtruth_boxes, valid_rows=tf.greater(groundtruth_weights, 0))
box_preds)
match = self._matcher.match(match_quality_matrix,
valid_rows=tf.greater(groundtruth_weights, 0))
reg_targets = self._create_regression_targets(box_preds, reg_targets = self._create_regression_targets(box_preds,
groundtruth_boxes, groundtruth_boxes,
match) match)
cls_targets = match.gather_based_on_match( cls_targets = match.gather_based_on_match(
groundtruth_labels, groundtruth_labels,
unmatched_value=unmatched_class_label, unmatched_value=unmatched_class_label,
ignored_value=unmatched_class_label) ignored_value=unmatched_class_label)
reg_weights = match.gather_based_on_match(groundtruth_weights, reg_weights = match.gather_based_on_match(groundtruth_weights,
ignored_value=0., ignored_value=0.,
unmatched_value=0.) unmatched_value=0.)
cls_weights = match.gather_based_on_match( cls_weights = match.gather_based_on_match(
groundtruth_weights, groundtruth_weights,
ignored_value=0., ignored_value=0.,
unmatched_value=self._negative_class_weight) unmatched_value=self._negative_class_weight)
# convert cls_weights from per-box_pred to per-class. # convert cls_weights from per-box_pred to per-class.
class_label_shape = tf.shape(cls_targets)[1:] class_label_shape = tf.shape(cls_targets)[1:]
weights_shape = tf.shape(cls_weights) weights_shape = tf.shape(cls_weights)
weights_multiple = tf.concat( weights_multiple = tf.concat(
[tf.ones_like(weights_shape), class_label_shape], [tf.ones_like(weights_shape), class_label_shape],
axis=0) axis=0)
for _ in range(len(cls_targets.get_shape()[1:])): for _ in range(len(cls_targets.get_shape()[1:])):
cls_weights = tf.expand_dims(cls_weights, -1) cls_weights = tf.expand_dims(cls_weights, -1)
cls_weights = tf.tile(cls_weights, weights_multiple) cls_weights = tf.tile(cls_weights, weights_multiple)
num_box_preds = box_preds.num_boxes_static() num_box_preds = box_preds.num_boxes_static()
if num_box_preds is not None: if num_box_preds is not None:
......
...@@ -19,7 +19,6 @@ import tensorflow.compat.v1 as tf ...@@ -19,7 +19,6 @@ import tensorflow.compat.v1 as tf
from object_detection.box_coders import keypoint_box_coder from object_detection.box_coders import keypoint_box_coder
from object_detection.box_coders import mean_stddev_box_coder from object_detection.box_coders import mean_stddev_box_coder
from object_detection.box_coders import detr_box_coder
from object_detection.core import box_list from object_detection.core import box_list
from object_detection.core import region_similarity_calculator from object_detection.core import region_similarity_calculator
from object_detection.core import standard_fields as fields from object_detection.core import standard_fields as fields
...@@ -2192,20 +2191,11 @@ class CornerOffsetTargetAssignerTest(test_case.TestCase): ...@@ -2192,20 +2191,11 @@ class CornerOffsetTargetAssignerTest(test_case.TestCase):
self.assertAllClose(foreground, np.zeros((1, 5, 5))) self.assertAllClose(foreground, np.zeros((1, 5, 5)))
if __name__ == '__main__': class DETRTargetAssignerTest(test_case.TestCase):
tf.enable_v2_behavior()
tf.test.main()
class DETRTargetAssignerTest(testcase.TestCase):
def test_assign_detr(self): def test_assign_detr(self):
def graph_fn(anchor_means, groundtruth_box_corners, def graph_fn(anchor_means, groundtruth_box_corners,
groundtruth_labels, predicted_labels): groundtruth_labels, predicted_labels):
similarity_calc = region_similarity_calculator.DETRSimilarity() detr_target_assigner = targetassigner.DETRTargetAssigner()
matcher = hungarian_matcher.HungarianBipartiteMatcher()
box_coder = detr_box_coder.DETRBoxCoder()
detr_target_assigner = target_assigner.DETRTargetAssigner(
similarity_calc, matcher, box_coder)
anchors_boxlist = box_list.BoxList(anchor_means) anchors_boxlist = box_list.BoxList(anchor_means)
groundtruth_boxlist = box_list.BoxList(groundtruth_box_corners) groundtruth_boxlist = box_list.BoxList(groundtruth_box_corners)
result = detr_target_assigner.assign( result = detr_target_assigner.assign(
...@@ -2247,4 +2237,8 @@ class DETRTargetAssignerTest(testcase.TestCase): ...@@ -2247,4 +2237,8 @@ class DETRTargetAssignerTest(testcase.TestCase):
self.assertEqual(cls_targets_out.dtype, np.float32) self.assertEqual(cls_targets_out.dtype, np.float32)
self.assertEqual(cls_weights_out.dtype, np.float32) self.assertEqual(cls_weights_out.dtype, np.float32)
self.assertEqual(reg_targets_out.dtype, np.float32) self.assertEqual(reg_targets_out.dtype, np.float32)
self.assertEqual(reg_weights_out.dtype, np.float32) self.assertEqual(reg_weights_out.dtype, np.float32)
\ No newline at end of file
if __name__ == '__main__':
tf.enable_v2_behavior()
tf.test.main()
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment