"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "c1ac2bfcb6d71e8fc4c775325676dccc6853ae35"
Commit b9785623 authored by TF Object Detection Team
Browse files

Merge pull request #9043 from kmindspark:detr-push-8

PiperOrigin-RevId: 325305594
parents 94a2302a f364daf0
...@@ -304,6 +304,50 @@ def iou(boxlist1, boxlist2, scope=None): ...@@ -304,6 +304,50 @@ def iou(boxlist1, boxlist2, scope=None):
tf.zeros_like(intersections), tf.truediv(intersections, unions)) tf.zeros_like(intersections), tf.truediv(intersections, unions))
def l1(boxlist1, boxlist2, scope=None):
  """Computes the pairwise L1 distance between the boxes of two boxlists.

  Boxes are compared in the form returned by
  get_center_coordinates_and_sizes(): the value for a pair is the sum of the
  absolute differences of the y-centers, x-centers, heights and widths.

  Args:
    boxlist1: BoxList holding N boxes
    boxlist2: BoxList holding M boxes
    scope: name scope.

  Returns:
    a tensor with shape [N, M] representing the pairwise L1 loss.
  """
  with tf.name_scope(scope, 'PairwiseL1'):
    coords1 = boxlist1.get_center_coordinates_and_sizes()
    coords2 = boxlist2.get_center_coordinates_and_sizes()
    # For every coordinate, a row built from boxlist2 minus a column built
    # from boxlist1 broadcasts to the grid of all box pairs.
    ycenters, xcenters, heights, widths = [
        tf.abs(
            tf.expand_dims(coord2, axis=0) -
            tf.expand_dims(tf.transpose(coord1), axis=1))
        for coord1, coord2 in zip(coords1, coords2)
    ]
    return ycenters + xcenters + heights + widths
def giou(boxlist1, boxlist2, scope=None):
  """Computes pairwise generalized IOU between two boxlists.

  Args:
    boxlist1: BoxList holding N boxes
    boxlist2: BoxList holding M boxes
    scope: name scope.

  Returns:
    a tensor with shape [N, M] whose entry [i, j] is the value ops.giou
    returns for box i of boxlist1 paired with box j of boxlist2.
  """
  with tf.name_scope(scope, 'PairwiseGIoU'):
    num_boxes1 = boxlist1.num_boxes()
    num_boxes2 = boxlist2.num_boxes()
    # Lay out all N * M pairs in row-major order: each box of boxlist1 is
    # repeated M times in a row, while boxlist2 is cycled through N times.
    repeated_boxes1 = tf.repeat(boxlist1.get(), repeats=num_boxes2, axis=0)
    tiled_boxes2 = tf.tile(boxlist2.get(), multiples=[num_boxes1, 1])
    flat_giou = ops.giou(repeated_boxes1, tiled_boxes2)
    return tf.reshape(flat_giou, [num_boxes1, num_boxes2])
def matched_iou(boxlist1, boxlist2, scope=None): def matched_iou(boxlist1, boxlist2, scope=None):
"""Compute intersection-over-union between corresponding boxes in boxlists. """Compute intersection-over-union between corresponding boxes in boxlists.
......
...@@ -229,6 +229,31 @@ class BoxListOpsTest(test_case.TestCase): ...@@ -229,6 +229,31 @@ class BoxListOpsTest(test_case.TestCase):
iou_output = self.execute(graph_fn, []) iou_output = self.execute(graph_fn, [])
self.assertAllClose(iou_output, exp_output) self.assertAllClose(iou_output, exp_output)
def test_l1(self):
  """Checks box_list_ops.l1 on a 2 x 3 grid of box pairs."""
  def graph_fn():
    boxes1 = box_list.BoxList(
        tf.constant([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]]))
    boxes2 = box_list.BoxList(
        tf.constant([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0],
                     [0.0, 0.0, 20.0, 20.0]]))
    return box_list_ops.l1(boxes1, boxes2)

  # Entry [i][j] is the expected L1 value for boxes1[i] paired with boxes2[j].
  expected_l1 = [[5.0, 22.5, 45.5], [8.5, 19.0, 40.0]]
  self.assertAllClose(self.execute(graph_fn, []), expected_l1)
def test_giou(self):
  """Checks box_list_ops.giou for an identical and a disjoint box pair."""
  def graph_fn():
    boxes1 = box_list.BoxList(tf.constant([[5.0, 7.0, 7.0, 9.0]]))
    boxes2 = box_list.BoxList(
        tf.constant([[5.0, 7.0, 7.0, 9.0], [5.0, 11.0, 7.0, 13.0]]))
    return box_list_ops.giou(boxes1, boxes2)

  # One row for the single box in boxes1, one column per box in boxes2.
  expected_giou = [[1.0, -1.0 / 3.0]]
  self.assertAllClose(self.execute(graph_fn, []), expected_giou)
def test_matched_iou(self): def test_matched_iou(self):
def graph_fn(): def graph_fn():
corners1 = tf.constant([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]]) corners1 = tf.constant([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]])
......
...@@ -36,6 +36,7 @@ import tensorflow.compat.v1 as tf ...@@ -36,6 +36,7 @@ import tensorflow.compat.v1 as tf
from object_detection.core import box_list from object_detection.core import box_list
from object_detection.core import box_list_ops from object_detection.core import box_list_ops
from object_detection.utils import ops from object_detection.utils import ops
from object_detection.utils import shape_utils
class Loss(six.with_metaclass(abc.ABCMeta, object)): class Loss(six.with_metaclass(abc.ABCMeta, object)):
...@@ -210,6 +211,38 @@ class WeightedIOULocalizationLoss(Loss): ...@@ -210,6 +211,38 @@ class WeightedIOULocalizationLoss(Loss):
return tf.reshape(weights, [-1]) * per_anchor_iou_loss return tf.reshape(weights, [-1]) * per_anchor_iou_loss
class WeightedGIOULocalizationLoss(Loss):
  """GIOU localization loss function.

  Each corresponding predicted/groundtruth box pair is assigned a loss of
  1 - GIOU, which is then multiplied by the weight of its anchor.
  """

  def _compute_loss(self, prediction_tensor, target_tensor, weights):
    """Compute loss function.

    Args:
      prediction_tensor: A float tensor of shape [batch_size, num_anchors, 4]
        representing the decoded predicted boxes
      target_tensor: A float tensor of shape [batch_size, num_anchors, 4]
        representing the decoded target boxes
      weights: a float tensor of shape [batch_size, num_anchors]

    Returns:
      loss: a float tensor of shape [batch_size, num_anchors] tensor
        representing the value of the loss function.
    """
    batch_size, num_anchors, _ = shape_utils.combined_static_and_dynamic_shape(
        prediction_tensor)
    # ops.giou works on flat lists of boxes, so fold the batch and anchor
    # dimensions together and restore them on the way out.
    flat_predictions = tf.reshape(prediction_tensor, [-1, 4])
    flat_targets = tf.reshape(target_tensor, [-1, 4])
    per_anchor_giou_loss = 1 - ops.giou(flat_predictions, flat_targets)
    weighted_loss = tf.reshape(weights, [-1]) * per_anchor_giou_loss
    return tf.reshape(weighted_loss, [batch_size, num_anchors])
class WeightedSigmoidClassificationLoss(Loss): class WeightedSigmoidClassificationLoss(Loss):
"""Sigmoid cross entropy classification loss function.""" """Sigmoid cross entropy classification loss function."""
......
...@@ -198,6 +198,47 @@ class WeightedIOULocalizationLossTest(test_case.TestCase): ...@@ -198,6 +198,47 @@ class WeightedIOULocalizationLossTest(test_case.TestCase):
self.assertAllClose(loss_output, exp_loss) self.assertAllClose(loss_output, exp_loss)
class WeightedGIOULocalizationLossTest(test_case.TestCase):
  """Tests for losses.WeightedGIOULocalizationLoss."""

  def testReturnsCorrectLoss(self):
    """Sums the weighted loss over two identical pairs and one disjoint pair."""
    def graph_fn():
      predictions = tf.constant([[[1.5, 0, 2.4, 1],
                                  [0, 0, 1, 1],
                                  [0, 0, 0, 0]]])
      targets = tf.constant([[[1.5, 0, 2.4, 1],
                              [0, 0, 1, 1],
                              [5, 5, 10, 10]]])
      anchor_weights = [[1.0, .5, 2.0]]
      giou_loss = losses.WeightedGIOULocalizationLoss()
      return tf.reduce_sum(
          giou_loss(predictions, targets, weights=anchor_weights))

    expected_loss = 3.5
    self.assertAllClose(self.execute(graph_fn, []), expected_loss)

  def testReturnsCorrectLossWithNoLabels(self):
    """Expects a zero loss when losses_mask is False for the batch element."""
    def graph_fn():
      predictions = tf.constant([[[1.5, 0, 2.4, 1],
                                  [0, 0, 1, 1],
                                  [0, 0, .5, .25]]])
      targets = tf.constant([[[1.5, 0, 2.4, 1],
                              [0, 0, 1, 1],
                              [50, 50, 500.5, 100.25]]])
      anchor_weights = [[1.0, .5, 2.0]]
      mask = tf.constant([False], tf.bool)
      giou_loss = losses.WeightedGIOULocalizationLoss()
      return tf.reduce_sum(
          giou_loss(predictions, targets, weights=anchor_weights,
                    losses_mask=mask))

    expected_loss = 0.0
    self.assertAllClose(self.execute(graph_fn, []), expected_loss)
class WeightedSigmoidClassificationLossTest(test_case.TestCase): class WeightedSigmoidClassificationLossTest(test_case.TestCase):
def testReturnsCorrectLoss(self): def testReturnsCorrectLoss(self):
......
...@@ -1134,3 +1134,64 @@ def decode_image(tensor_dict): ...@@ -1134,3 +1134,64 @@ def decode_image(tensor_dict):
tensor_dict[fields.InputDataFields.image], channels=3) tensor_dict[fields.InputDataFields.image], channels=3)
tensor_dict[fields.InputDataFields.image].set_shape([None, None, 3]) tensor_dict[fields.InputDataFields.image].set_shape([None, None, 3])
return tensor_dict return tensor_dict
def giou(boxes1, boxes2):
  """Computes generalized IOU between two tensors.

  Each box should be represented as [ymin, xmin, ymax, xmax]. Box i of boxes1
  is compared with box i of boxes2.

  The whole computation is elementwise, so it is evaluated for all boxes at
  once rather than by mapping a per-box function over the inputs.

  Args:
    boxes1: a tensor with shape [num_boxes, 4]
    boxes2: a tensor with shape [num_boxes, 4]

  Returns:
    a tensor of shape [num_boxes] containing GIoUs
  """
  # Split the coordinate axis; every resulting tensor has shape [num_boxes].
  pred_ymin, pred_xmin, pred_ymax, pred_xmax = tf.unstack(boxes1, axis=1)
  gt_ymin, gt_xmin, gt_ymax, gt_xmax = tf.unstack(boxes2, axis=1)

  gt_area = (gt_ymax - gt_ymin) * (gt_xmax - gt_xmin)
  pred_area = (pred_ymax - pred_ymin) * (pred_xmax - pred_xmin)

  # Intersection rectangle; negative extents (no overlap) are clamped to 0.
  x1_i = tf.maximum(pred_xmin, gt_xmin)
  x2_i = tf.minimum(pred_xmax, gt_xmax)
  y1_i = tf.maximum(pred_ymin, gt_ymin)
  y2_i = tf.minimum(pred_ymax, gt_ymax)
  intersection_area = tf.maximum(0.0, y2_i - y1_i) * tf.maximum(0.0,
                                                                x2_i - x1_i)

  # Smallest axis-aligned rectangle enclosing both boxes.
  x1_c = tf.minimum(pred_xmin, gt_xmin)
  x2_c = tf.maximum(pred_xmax, gt_xmax)
  y1_c = tf.minimum(pred_ymin, gt_ymin)
  y2_c = tf.maximum(pred_ymax, gt_ymax)
  hull_area = (y2_c - y1_c) * (x2_c - x1_c)

  union_area = gt_area + pred_area - intersection_area
  # An empty union yields an IoU of 0 instead of the result of 0 / 0.
  iou = tf.where(
      tf.equal(union_area, 0.0), tf.zeros_like(union_area),
      intersection_area / union_area)
  # With an empty hull the IoU itself is subtracted, so the GIoU is 0.
  giou_ = iou - tf.where(hull_area > 0.0,
                         (hull_area - union_area) / hull_area, iou)
  return giou_
def center_to_corner_coordinate(input_tensor):
  """Converts input boxes from center to corner representation.

  Args:
    input_tensor: a tensor that is reshaped to [-1, 4]; each resulting row is
      read as [ycenter, xcenter, height, width].

  Returns:
    the boxes as [ymin, xmin, ymax, xmax]. NOTE: the result passes through
    tf.squeeze without an axis argument, so every size-1 dimension is
    dropped; a single input box therefore comes back with shape [4] rather
    than [1, 4].
  """
  boxes = tf.reshape(input_tensor, [-1, 4])

  def _column(index):
    """Returns column `index` of `boxes`, keeping a trailing size-1 axis."""
    return tf.gather(boxes, [index], axis=1)

  ycenter = _column(0)
  xcenter = _column(1)
  half_height = _column(2) / 2.
  half_width = _column(3) / 2.
  corners = tf.stack([
      ycenter - half_height,
      xcenter - half_width,
      ycenter + half_height,
      xcenter + half_width,
  ], axis=1)
  return tf.squeeze(corners)
...@@ -1635,5 +1635,119 @@ class TestGatherWithPaddingValues(test_case.TestCase): ...@@ -1635,5 +1635,119 @@ class TestGatherWithPaddingValues(test_case.TestCase):
class TestGIoU(test_case.TestCase):
  """Tests for ops.giou."""

  def _run_giou(self, boxes1, boxes2):
    """Evaluates ops.giou on two float32 box lists and checks its dtype."""
    def graph_fn():
      giou = ops.giou(tf.constant(boxes1, dtype=tf.float32),
                      tf.constant(boxes2, dtype=tf.float32))
      self.assertEqual(giou.dtype, tf.float32)
      return giou

    return self.execute(graph_fn, [])

  def test_giou_with_no_overlap(self):
    boxes1 = [[3, 4, 5, 6], [3, 3, 5, 5],
              [0, 0, 0, 0], [3, 3, 5, 5],
              [9, 9, 10, 10]]
    boxes2 = [[3, 2, 5, 4], [3, 7, 5, 9],
              [5, 5, 10, 10], [3, 5, 5, 7],
              [0, 0, 1, 1]]
    expected_giou = [0, -1/3, -3/4, 0, -98/100]
    self.assertAllClose(expected_giou, self._run_giou(boxes1, boxes2))

  def test_giou_with_overlaps(self):
    boxes1 = [[2, 1, 7, 6], [2, 2, 4, 4],
              [2, 2, 4, 4], [2, 2, 4, 4]]
    boxes2 = [[4, 3, 5, 4], [3, 3, 4, 4],
              [2, 3, 4, 5], [3, 3, 5, 5]]
    expected_giou = [1/25, 1/4, 1/3, 1/7 - 2/9]
    self.assertAllClose(expected_giou, self._run_giou(boxes1, boxes2))

  def test_giou_with_perfect_overlap(self):
    boxes = [[3, 3, 5, 5]]
    self.assertAllClose([1], self._run_giou(boxes, boxes))

  def test_giou_with_zero_area_boxes(self):
    boxes = [[1, 1, 1, 1]]
    self.assertAllClose([0], self._run_giou(boxes, boxes))

  def test_giou_different_with_l1_same(self):
    boxes1 = [[3, 3, 5, 5], [3, 3, 5, 5]]
    boxes2 = [[3, 2.5, 5, 5.5], [3, 2.5, 5, 4.5]]
    expected_giou = [2/3, 3/5]
    self.assertAllClose(expected_giou, self._run_giou(boxes1, boxes2))
class TestCoordinateConversion(test_case.TestCase):
  """Tests for ops.center_to_corner_coordinate."""

  def test_coord_conv(self):
    """Checks the conversion of three boxes, including an all-zero box."""
    def graph_fn():
      center_boxes = tf.constant(
          [[3, 3, 5, 5], [3, 4, 2, 6], [0, 0, 0, 0]], dtype=tf.float32)
      corner_boxes = ops.center_to_corner_coordinate(center_boxes)
      self.assertEqual(corner_boxes.dtype, tf.float32)
      return corner_boxes

    expected_corners = [[0.5, 0.5, 5.5, 5.5], [2, 1, 4, 7], [0, 0, 0, 0]]
    self.assertAllClose(expected_corners, self.execute(graph_fn, []))
# Run the test suite when this file is executed as a script.
if __name__ == '__main__':
  tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment