Commit e09e0566 authored by Kaushik Shivakumar's avatar Kaushik Shivakumar
Browse files

target assigner and similarity calculator fixes

parent d54c86de
...@@ -53,7 +53,7 @@ class RegionSimilarityCalculator(six.with_metaclass(ABCMeta, object)): ...@@ -53,7 +53,7 @@ class RegionSimilarityCalculator(six.with_metaclass(ABCMeta, object)):
a (float32) tensor of shape [N, M] with pairwise similarity score. a (float32) tensor of shape [N, M] with pairwise similarity score.
""" """
with tf.name_scope(scope, 'Compare', [boxlist1, boxlist2]) as scope: with tf.name_scope(scope, 'Compare', [boxlist1, boxlist2]) as scope:
return self._compare(boxlist1, boxlist2, groundtruth_labels, predicted_labels) return self._compare(boxlist1, boxlist2)
@abstractmethod @abstractmethod
def _compare(self, boxlist1, boxlist2, def _compare(self, boxlist1, boxlist2,
......
...@@ -103,7 +103,8 @@ class RegionSimilarityCalculatorTest(test_case.TestCase): ...@@ -103,7 +103,8 @@ class RegionSimilarityCalculatorTest(test_case.TestCase):
boxes2 = box_list.BoxList(corners2) boxes2 = box_list.BoxList(corners2)
boxes1.add_field(fields.BoxListFields.classes, groundtruth_labels) boxes1.add_field(fields.BoxListFields.classes, groundtruth_labels)
boxes2.add_field(fields.BoxListFields.classes, predicted_labels) boxes2.add_field(fields.BoxListFields.classes, predicted_labels)
detr_similarity_calculator = region_similarity_calculator.DETRSimilarity() detr_similarity_calculator = \
region_similarity_calculator.DETRSimilarity()
detr_similarity = detr_similarity_calculator.compare( detr_similarity = detr_similarity_calculator.compare(
boxes1, boxes2, None) boxes1, boxes2, None)
return detr_similarity return detr_similarity
......
...@@ -437,9 +437,7 @@ def create_target_assigner(reference, stage=None, ...@@ -437,9 +437,7 @@ def create_target_assigner(reference, stage=None,
box_coder_instance = faster_rcnn_box_coder.FasterRcnnBoxCoder() box_coder_instance = faster_rcnn_box_coder.FasterRcnnBoxCoder()
elif reference == 'DETR': elif reference == 'DETR':
similarity_calc = sim_calc.DETRSimilarity() return DETRTargetAssigner()
matcher = hungarian_matcher.HungarianBipartiteMatcher()
return DETRTargetAssigner(similarity_calc, matcher)
else: else:
raise ValueError('No valid combination of reference and stage.') raise ValueError('No valid combination of reference and stage.')
...@@ -1917,9 +1915,7 @@ class CenterNetCornerOffsetTargetAssigner(object): ...@@ -1917,9 +1915,7 @@ class CenterNetCornerOffsetTargetAssigner(object):
class DETRTargetAssigner(object): class DETRTargetAssigner(object):
"""Target assigner to compute classification and regression targets.""" """Target assigner to compute classification and regression targets."""
def __init__(self, def __init__(self, negative_class_weight=1.0):
matcher,
negative_class_weight=1.0):
"""Construct Object Detection Target Assigner. """Construct Object Detection Target Assigner.
Args: Args:
...@@ -1931,8 +1927,6 @@ class DETRTargetAssigner(object): ...@@ -1931,8 +1927,6 @@ class DETRTargetAssigner(object):
boxes (default: 1.0). The weight must be in [0., 1.]. boxes (default: 1.0). The weight must be in [0., 1.].
""" """
if not isinstance(matcher, mat.Matcher):
raise ValueError('matcher must be a Matcher')
self._similarity_calc = sim_calc.DETRSimilarity() self._similarity_calc = sim_calc.DETRSimilarity()
self._matcher = hungarian_matcher.HungarianBipartiteMatcher() self._matcher = hungarian_matcher.HungarianBipartiteMatcher()
self._negative_class_weight = negative_class_weight self._negative_class_weight = negative_class_weight
...@@ -2024,9 +2018,6 @@ class DETRTargetAssigner(object): ...@@ -2024,9 +2018,6 @@ class DETRTargetAssigner(object):
groundtruth_boxes.add_field(fields.BoxListFields.classes, groundtruth_labels) groundtruth_boxes.add_field(fields.BoxListFields.classes, groundtruth_labels)
box_preds.add_field(fields.BoxListFields.classes, class_predictions) box_preds.add_field(fields.BoxListFields.classes, class_predictions)
with tf.control_dependencies(
[unmatched_shape_assert, labels_and_box_shapes_assert]):
match_quality_matrix = self._similarity_calc.compare( match_quality_matrix = self._similarity_calc.compare(
groundtruth_boxes, groundtruth_boxes,
box_preds) box_preds)
......
...@@ -19,7 +19,6 @@ import tensorflow.compat.v1 as tf ...@@ -19,7 +19,6 @@ import tensorflow.compat.v1 as tf
from object_detection.box_coders import keypoint_box_coder from object_detection.box_coders import keypoint_box_coder
from object_detection.box_coders import mean_stddev_box_coder from object_detection.box_coders import mean_stddev_box_coder
from object_detection.box_coders import detr_box_coder
from object_detection.core import box_list from object_detection.core import box_list
from object_detection.core import region_similarity_calculator from object_detection.core import region_similarity_calculator
from object_detection.core import standard_fields as fields from object_detection.core import standard_fields as fields
...@@ -2192,20 +2191,11 @@ class CornerOffsetTargetAssignerTest(test_case.TestCase): ...@@ -2192,20 +2191,11 @@ class CornerOffsetTargetAssignerTest(test_case.TestCase):
self.assertAllClose(foreground, np.zeros((1, 5, 5))) self.assertAllClose(foreground, np.zeros((1, 5, 5)))
if __name__ == '__main__': class DETRTargetAssignerTest(test_case.TestCase):
tf.enable_v2_behavior()
tf.test.main()
class DETRTargetAssignerTest(testcase.TestCase):
def test_assign_detr(self): def test_assign_detr(self):
def graph_fn(anchor_means, groundtruth_box_corners, def graph_fn(anchor_means, groundtruth_box_corners,
groundtruth_labels, predicted_labels): groundtruth_labels, predicted_labels):
similarity_calc = region_similarity_calculator.DETRSimilarity() detr_target_assigner = targetassigner.DETRTargetAssigner()
matcher = hungarian_matcher.HungarianBipartiteMatcher()
box_coder = detr_box_coder.DETRBoxCoder()
detr_target_assigner = target_assigner.DETRTargetAssigner(
similarity_calc, matcher, box_coder)
anchors_boxlist = box_list.BoxList(anchor_means) anchors_boxlist = box_list.BoxList(anchor_means)
groundtruth_boxlist = box_list.BoxList(groundtruth_box_corners) groundtruth_boxlist = box_list.BoxList(groundtruth_box_corners)
result = detr_target_assigner.assign( result = detr_target_assigner.assign(
...@@ -2248,3 +2238,7 @@ class DETRTargetAssignerTest(testcase.TestCase): ...@@ -2248,3 +2238,7 @@ class DETRTargetAssignerTest(testcase.TestCase):
self.assertEqual(cls_weights_out.dtype, np.float32) self.assertEqual(cls_weights_out.dtype, np.float32)
self.assertEqual(reg_targets_out.dtype, np.float32) self.assertEqual(reg_targets_out.dtype, np.float32)
self.assertEqual(reg_weights_out.dtype, np.float32) self.assertEqual(reg_weights_out.dtype, np.float32)
if __name__ == '__main__':
tf.enable_v2_behavior()
tf.test.main()
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment