Commit ba65cc7e authored by Kaushik Shivakumar's avatar Kaushik Shivakumar
Browse files

fix tests

parent 322d4444
...@@ -47,6 +47,10 @@ class RegionSimilarityCalculator(six.with_metaclass(ABCMeta, object)): ...@@ -47,6 +47,10 @@ class RegionSimilarityCalculator(six.with_metaclass(ABCMeta, object)):
boxlist1: BoxList holding N boxes. boxlist1: BoxList holding N boxes.
boxlist2: BoxList holding M boxes. boxlist2: BoxList holding M boxes.
scope: Op scope name. Defaults to 'Compare' if None. scope: Op scope name. Defaults to 'Compare' if None.
groundtruth_labels: a Tensor of shape [num_boxes, num_classes]
containing groundtruth labels.
predicted_labels: a Tensor of shape [num_boxes, num_classes]
containing predicted labels.
Returns: Returns:
a (float32) tensor of shape [N, M] with pairwise similarity score. a (float32) tensor of shape [N, M] with pairwise similarity score.
......
...@@ -95,18 +95,18 @@ class RegionSimilarityCalculatorTest(test_case.TestCase): ...@@ -95,18 +95,18 @@ class RegionSimilarityCalculatorTest(test_case.TestCase):
def test_detr_similarity(self): def test_detr_similarity(self):
def graph_fn(): def graph_fn():
corners1 = tf.constant([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]]) corners1 = tf.constant([[5.0, 7.0, 7.0, 9.0]])
corners2 = tf.constant([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0], corners2 = tf.constant([[5.0, 7.0, 7.0, 9.0], [5.0, 11.0, 7.0, 13.0]])
[0.0, 0.0, 20.0, 20.0]]) groundtruth_labels = tf.constant([[1.0, 0.0]])
groundtruth_labels = tf.constant([[]]) predicted_labels = tf.constant([[0.0, 1000.0], [1000.0, 0.0]])
boxes1 = box_list.BoxList(corners1) boxes1 = box_list.BoxList(corners1)
boxes2 = box_list.BoxList(corners2) boxes2 = box_list.BoxList(corners2)
iou_similarity_calculator = region_similarity_calculator.IouSimilarity() detr_similarity_calculator = region_similarity_calculator.DETRSimiliarity()
iou_similarity = iou_similarity_calculator.compare(boxes1, boxes2) detr_similarity = detr_similarity_calculator.compare(boxes1, boxes2, None, groundtruth_labels, predicted_labels)
return iou_similarity return detr_similarity
exp_output = [[2.0 / 16.0, 0, 6.0 / 400.0], [1.0 / 16.0, 0.0, 5.0 / 400.0]] exp_output = [[2.0, -2.0/3.0 + 1.0 - 20.0]]
iou_output = self.execute(graph_fn, []) sim_output = self.execute(graph_fn, [])
self.assertAllClose(iou_output, exp_output) self.assertAllClose(sim_output, exp_output)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment