Commit f56dedec authored by Yoni Ben-Meshulam's avatar Yoni Ben-Meshulam Committed by TF Object Detection Team
Browse files

Test for ignoring skip_predictions_for_unlabeled_class if groundtruth_labeled_classes is unset.

PiperOrigin-RevId: 336097508
parent 481cf8da
...@@ -255,9 +255,7 @@ class CocoDetectionEvaluationTest(tf.test.TestCase): ...@@ -255,9 +255,7 @@ class CocoDetectionEvaluationTest(tf.test.TestCase):
@unittest.skipIf(tf_version.is_tf2(), 'Only Supported in TF1.X') @unittest.skipIf(tf_version.is_tf2(), 'Only Supported in TF1.X')
class CocoEvaluationPyFuncTest(tf.test.TestCase): class CocoEvaluationPyFuncTest(tf.test.TestCase):
def testGetOneMAPWithMatchingGroundtruthAndDetections(self): def _MatchingGroundtruthAndDetections(self, coco_evaluator):
coco_evaluator = coco_evaluation.CocoDetectionEvaluator(
_get_categories_list())
image_id = tf.placeholder(tf.string, shape=()) image_id = tf.placeholder(tf.string, shape=())
groundtruth_boxes = tf.placeholder(tf.float32, shape=(None, 4)) groundtruth_boxes = tf.placeholder(tf.float32, shape=(None, 4))
groundtruth_classes = tf.placeholder(tf.float32, shape=(None)) groundtruth_classes = tf.placeholder(tf.float32, shape=(None))
...@@ -330,6 +328,20 @@ class CocoEvaluationPyFuncTest(tf.test.TestCase): ...@@ -330,6 +328,20 @@ class CocoEvaluationPyFuncTest(tf.test.TestCase):
self.assertFalse(coco_evaluator._detection_boxes_list) self.assertFalse(coco_evaluator._detection_boxes_list)
self.assertFalse(coco_evaluator._image_ids) self.assertFalse(coco_evaluator._image_ids)
def testGetOneMAPWithMatchingGroundtruthAndDetections(self):
coco_evaluator = coco_evaluation.CocoDetectionEvaluator(
_get_categories_list())
self._MatchingGroundtruthAndDetections(coco_evaluator)
# Configured to skip unmatched detector predictions with
# groundtruth_labeled_classes, but reverts to fully-labeled eval since there
# are no groundtruth_labeled_classes set.
def testGetMAPWithSkipUnmatchedPredictionsIgnoreGrountruthLabeledClasses(
self):
coco_evaluator = coco_evaluation.CocoDetectionEvaluator(
_get_categories_list(), skip_predictions_for_unlabeled_class=True)
self._MatchingGroundtruthAndDetections(coco_evaluator)
# Test skipping unmatched detector predictions with # Test skipping unmatched detector predictions with
# groundtruth_labeled_classes. # groundtruth_labeled_classes.
def testGetMAPWithSkipUnmatchedPredictions(self): def testGetMAPWithSkipUnmatchedPredictions(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment