Commit 13a030ed authored by TF Object Detection Team's avatar TF Object Detection Team
Browse files

Merge pull request #8962 from kmindspark:detr-push-3

PiperOrigin-RevId: 331783117
parents 50cc55f8 0bc599e7
# Lint as: python3
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -79,6 +80,39 @@ class IouSimilarity(RegionSimilarityCalculator):
return box_list_ops.iou(boxlist1, boxlist2)
class DETRSimilarity(RegionSimilarityCalculator):
"""Class to compute similarity for the Detection Transformer model.
This class computes pairwise DETR similarity between two BoxLists using a
weighted combination of GIOU, classification scores, and the L1 loss.
"""
def __init__(self, l1_weight=5, giou_weight=2):
super().__init__()
self.l1_weight = l1_weight
self.giou_weight = giou_weight
def _compare(self, boxlist1, boxlist2):
"""Compute pairwise DETR similarity between the two BoxLists.
Args:
boxlist1: BoxList holding N groundtruth boxes.
boxlist2: BoxList holding M predicted boxes.
Returns:
A tensor with shape [N, M] representing pairwise DETR similarity scores.
"""
groundtruth_labels = boxlist1.get_field(fields.BoxListFields.classes)
predicted_labels = boxlist2.get_field(fields.BoxListFields.classes)
classification_scores = tf.matmul(groundtruth_labels,
predicted_labels,
transpose_b=True)
loss = self.l1_weight * box_list_ops.l1(
boxlist1, boxlist2) + self.giou_weight * (1 - box_list_ops.giou(
boxlist1, boxlist2)) - classification_scores
return -loss
class NegSqDistSimilarity(RegionSimilarityCalculator):
"""Class to compute similarity based on the squared distance metric.
......
......@@ -93,6 +93,25 @@ class RegionSimilarityCalculatorTest(test_case.TestCase):
iou_output = self.execute(graph_fn, [])
self.assertAllClose(iou_output, exp_output)
def test_detr_similarity(self):
def graph_fn():
corners1 = tf.constant([[5.0, 7.0, 7.0, 9.0]])
corners2 = tf.constant([[5.0, 7.0, 7.0, 9.0], [5.0, 11.0, 7.0, 13.0]])
groundtruth_labels = tf.constant([[1.0, 0.0]])
predicted_labels = tf.constant([[0.0, 1000.0], [1000.0, 0.0]])
boxes1 = box_list.BoxList(corners1)
boxes2 = box_list.BoxList(corners2)
boxes1.add_field(fields.BoxListFields.classes, groundtruth_labels)
boxes2.add_field(fields.BoxListFields.classes, predicted_labels)
detr_similarity_calculator = \
region_similarity_calculator.DETRSimilarity()
detr_similarity = detr_similarity_calculator.compare(
boxes1, boxes2, None)
return detr_similarity
exp_output = [[0.0, -20 - 8.0/3.0 + 1000.0]]
sim_output = self.execute(graph_fn, [])
self.assertAllClose(sim_output, exp_output)
if __name__ == '__main__':
tf.test.main()
......@@ -51,6 +51,7 @@ from object_detection.core import matcher as mat
from object_detection.core import region_similarity_calculator as sim_calc
from object_detection.core import standard_fields as fields
from object_detection.matchers import argmax_matcher
from object_detection.matchers import hungarian_matcher
from object_detection.utils import shape_utils
from object_detection.utils import target_assigner_utils as ta_utils
from object_detection.utils import tf_version
......@@ -510,7 +511,8 @@ def batch_assign(target_assigner,
anchors_batch, gt_box_batch, gt_class_targets_batch, gt_weights_batch):
(cls_targets, cls_weights,
reg_targets, reg_weights, match) = target_assigner.assign(
anchors, gt_boxes, gt_class_targets, unmatched_class_label, gt_weights)
anchors, gt_boxes, gt_class_targets, unmatched_class_label,
gt_weights)
cls_targets_list.append(cls_targets)
cls_weights_list.append(cls_weights)
reg_targets_list.append(reg_targets)
......@@ -2082,3 +2084,189 @@ class CenterNetTemporalOffsetTargetAssigner(object):
batch_temporal_offsets = tf.concat(batch_temporal_offsets, axis=0)
return (batch_indices, batch_temporal_offsets, batch_weights)
class DETRTargetAssigner(object):
"""Target assigner for DETR (https://arxiv.org/abs/2005.12872).
Detection Transformer (DETR) matches predicted boxes to groundtruth directly
to determine targets instead of matching anchors to groundtruth. Hence, the
new target assigner.
"""
def __init__(self):
"""Construct Object Detection Target Assigner."""
self._similarity_calc = sim_calc.DETRSimilarity()
self._matcher = hungarian_matcher.HungarianBipartiteMatcher()
def batch_assign(self,
pred_box_batch,
gt_box_batch,
pred_class_batch,
gt_class_targets_batch,
gt_weights_batch=None,
unmatched_class_label_batch=None):
"""Batched assignment of classification and regression targets.
Args:
pred_box_batch: a tensor of shape [batch_size, num_queries, 4]
representing predicted bounding boxes.
gt_box_batch: a tensor of shape [batch_size, num_queries, 4]
representing groundtruth bounding boxes.
pred_class_batch: A list of tensors with length batch_size, where each
each tensor has shape [num_queries, num_classes] to be used
by certain similarity calculators.
gt_class_targets_batch: a list of tensors with length batch_size, where
each tensor has shape [num_gt_boxes_i, num_classes] and
num_gt_boxes_i is the number of boxes in the ith boxlist of
gt_box_batch.
gt_weights_batch: A list of 1-D tf.float32 tensors of shape
[num_boxes] containing weights for groundtruth boxes.
unmatched_class_label_batch: a float32 tensor with shape
[d_1, d_2, ..., d_k] which is consistent with the classification target
for each anchor (and can be empty for scalar targets). This shape must
thus be compatible with the `gt_class_targets_batch`.
Returns:
batch_cls_targets: a tensor with shape [batch_size, num_pred_boxes,
num_classes],
batch_cls_weights: a tensor with shape [batch_size, num_pred_boxes,
num_classes],
batch_reg_targets: a tensor with shape [batch_size, num_pred_boxes,
box_code_dimension]
batch_reg_weights: a tensor with shape [batch_size, num_pred_boxes].
"""
pred_box_batch = [
box_list.BoxList(pred_box)
for pred_box in tf.unstack(pred_box_batch)]
gt_box_batch = [
box_list.BoxList(gt_box)
for gt_box in tf.unstack(gt_box_batch)]
cls_targets_list = []
cls_weights_list = []
reg_targets_list = []
reg_weights_list = []
if gt_weights_batch is None:
gt_weights_batch = [None] * len(gt_class_targets_batch)
if unmatched_class_label_batch is None:
unmatched_class_label_batch = [None] * len(gt_class_targets_batch)
pred_class_batch = tf.unstack(pred_class_batch)
for (pred_boxes, gt_boxes, pred_class_batch, gt_class_targets, gt_weights,
unmatched_class_label) in zip(pred_box_batch, gt_box_batch,
pred_class_batch, gt_class_targets_batch,
gt_weights_batch,
unmatched_class_label_batch):
(cls_targets, cls_weights, reg_targets,
reg_weights) = self.assign(pred_boxes, gt_boxes, pred_class_batch,
gt_class_targets, gt_weights,
unmatched_class_label)
cls_targets_list.append(cls_targets)
cls_weights_list.append(cls_weights)
reg_targets_list.append(reg_targets)
reg_weights_list.append(reg_weights)
batch_cls_targets = tf.stack(cls_targets_list)
batch_cls_weights = tf.stack(cls_weights_list)
batch_reg_targets = tf.stack(reg_targets_list)
batch_reg_weights = tf.stack(reg_weights_list)
return (batch_cls_targets, batch_cls_weights, batch_reg_targets,
batch_reg_weights)
def assign(self,
pred_boxes,
gt_boxes,
pred_classes,
gt_labels,
gt_weights=None,
unmatched_class_label=None):
"""Assign classification and regression targets to each box_pred.
For a given set of pred_boxes and groundtruth detections, match pred_boxes
to gt_boxes and assign classification and regression targets to
each box_pred as well as weights based on the resulting match (specifying,
e.g., which pred_boxes should not contribute to training loss).
pred_boxes that are not matched to anything are given a classification
target of `unmatched_cls_target`.
Args:
pred_boxes: a BoxList representing N pred_boxes
gt_boxes: a BoxList representing M groundtruth boxes
pred_classes: A tensor with shape [max_num_boxes, num_classes]
to be used by certain similarity calculators.
gt_labels: a tensor of shape [M, num_classes]
with labels for each of the ground_truth boxes. The subshape
[num_classes] can be empty (corresponding to scalar inputs). When set
to None, gt_labels assumes a binary problem where all
ground_truth boxes get a positive label (of 1).
gt_weights: a float tensor of shape [M] indicating the weight to
assign to all pred_boxes match to a particular groundtruth box. The
weights must be in [0., 1.]. If None, all weights are set to 1.
Generally no groundtruth boxes with zero weight match to any pred_boxes
as matchers are aware of groundtruth weights. Additionally,
`cls_weights` and `reg_weights` are calculated using groundtruth
weights as an added safety.
unmatched_class_label: a float32 tensor with shape [d_1, d_2, ..., d_k]
which is consistent with the classification target for each
anchor (and can be empty for scalar targets). This shape must thus be
compatible with the groundtruth labels that are passed to the "assign"
function (which have shape [num_gt_boxes, d_1, d_2, ..., d_k]).
Returns:
cls_targets: a float32 tensor with shape [num_pred_boxes, num_classes],
where the subshape [num_classes] is compatible with gt_labels
which has shape [num_gt_boxes, num_classes].
cls_weights: a float32 tensor with shape [num_pred_boxes, num_classes],
representing weights for each element in cls_targets.
reg_targets: a float32 tensor with shape [num_pred_boxes,
box_code_dimension]
reg_weights: a float32 tensor with shape [num_pred_boxes]
"""
if not unmatched_class_label:
unmatched_class_label = tf.constant(
[1] + [0] * (gt_labels.shape[1] - 1), tf.float32)
if gt_weights is None:
num_gt_boxes = gt_boxes.num_boxes_static()
if not num_gt_boxes:
num_gt_boxes = gt_boxes.num_boxes()
gt_weights = tf.ones([num_gt_boxes], dtype=tf.float32)
gt_boxes.add_field(fields.BoxListFields.classes, gt_labels)
pred_boxes.add_field(fields.BoxListFields.classes, pred_classes)
match_quality_matrix = self._similarity_calc.compare(
gt_boxes,
pred_boxes)
match = self._matcher.match(match_quality_matrix,
valid_rows=tf.greater(gt_weights, 0))
matched_gt_boxes = match.gather_based_on_match(
gt_boxes.get(),
unmatched_value=tf.zeros(4),
ignored_value=tf.zeros(4))
matched_gt_boxlist = box_list.BoxList(matched_gt_boxes)
ty, tx, th, tw = matched_gt_boxlist.get_center_coordinates_and_sizes()
reg_targets = tf.transpose(tf.stack([ty, tx, th, tw]))
cls_targets = match.gather_based_on_match(
gt_labels,
unmatched_value=unmatched_class_label,
ignored_value=unmatched_class_label)
reg_weights = match.gather_based_on_match(
gt_weights,
ignored_value=0.,
unmatched_value=0.)
cls_weights = match.gather_based_on_match(
gt_weights,
ignored_value=0.,
unmatched_value=1)
# convert cls_weights from per-box_pred to per-class.
class_label_shape = tf.shape(cls_targets)[1:]
weights_multiple = tf.concat(
[tf.constant([1]), class_label_shape],
axis=0)
cls_weights = tf.expand_dims(cls_weights, -1)
cls_weights = tf.tile(cls_weights, weights_multiple)
return (cls_targets, cls_weights, reg_targets, reg_weights)
......@@ -115,6 +115,7 @@ class TargetAssignerTest(test_case.TestCase):
self.assertEqual(reg_weights_out.dtype, np.float32)
def test_assign_agnostic_with_keypoints(self):
def graph_fn(anchor_means, groundtruth_box_corners,
groundtruth_keypoints):
similarity_calc = region_similarity_calculator.IouSimilarity()
......@@ -2410,6 +2411,95 @@ class CenterNetTemporalOffsetTargetAssigner(test_case.TestCase):
np.testing.assert_array_equal(weights, [0, 0, 1, 1, 1])
class DETRTargetAssignerTest(test_case.TestCase):
def test_assign_detr(self):
def graph_fn(pred_corners, groundtruth_box_corners,
groundtruth_labels, predicted_labels):
detr_target_assigner = targetassigner.DETRTargetAssigner()
pred_boxlist = box_list.BoxList(pred_corners)
groundtruth_boxlist = box_list.BoxList(groundtruth_box_corners)
result = detr_target_assigner.assign(
pred_boxlist, groundtruth_boxlist,
predicted_labels, groundtruth_labels)
(cls_targets, cls_weights, reg_targets, reg_weights) = result
return (cls_targets, cls_weights, reg_targets, reg_weights)
pred_corners = np.array([[0.25, 0.25, 0.4, 0.2],
[0.5, 0.8, 1.0, 0.8],
[0.9, 0.5, 0.1, 1.0]], dtype=np.float32)
groundtruth_box_corners = np.array([[0.0, 0.0, 0.5, 0.5],
[0.5, 0.5, 0.9, 0.9]],
dtype=np.float32)
predicted_labels = np.array([[-3.0, 3.0], [2.0, 9.4], [5.0, 1.0]],
dtype=np.float32)
groundtruth_labels = np.array([[0.0, 1.0], [0.0, 1.0]],
dtype=np.float32)
exp_cls_targets = [[0, 1], [0, 1], [1, 0]]
exp_cls_weights = [[1, 1], [1, 1], [1, 1]]
exp_reg_targets = [[0.25, 0.25, 0.5, 0.5],
[0.7, 0.7, 0.4, 0.4],
[0, 0, 0, 0]]
exp_reg_weights = [1, 1, 0]
(cls_targets_out,
cls_weights_out, reg_targets_out, reg_weights_out) = self.execute_cpu(
graph_fn, [pred_corners, groundtruth_box_corners,
groundtruth_labels, predicted_labels])
self.assertAllClose(cls_targets_out, exp_cls_targets)
self.assertAllClose(cls_weights_out, exp_cls_weights)
self.assertAllClose(reg_targets_out, exp_reg_targets)
self.assertAllClose(reg_weights_out, exp_reg_weights)
self.assertEqual(cls_targets_out.dtype, np.float32)
self.assertEqual(cls_weights_out.dtype, np.float32)
self.assertEqual(reg_targets_out.dtype, np.float32)
self.assertEqual(reg_weights_out.dtype, np.float32)
def test_batch_assign_detr(self):
def graph_fn(pred_corners, groundtruth_box_corners,
groundtruth_labels, predicted_labels):
detr_target_assigner = targetassigner.DETRTargetAssigner()
result = detr_target_assigner.batch_assign(
pred_corners, groundtruth_box_corners,
[predicted_labels], [groundtruth_labels])
(cls_targets, cls_weights, reg_targets, reg_weights) = result
return (cls_targets, cls_weights, reg_targets, reg_weights)
pred_corners = np.array([[[0.25, 0.25, 0.4, 0.2],
[0.5, 0.8, 1.0, 0.8],
[0.9, 0.5, 0.1, 1.0]]], dtype=np.float32)
groundtruth_box_corners = np.array([[[0.0, 0.0, 0.5, 0.5],
[0.5, 0.5, 0.9, 0.9]]],
dtype=np.float32)
predicted_labels = np.array([[-3.0, 3.0], [2.0, 9.4], [5.0, 1.0]],
dtype=np.float32)
groundtruth_labels = np.array([[0.0, 1.0], [0.0, 1.0]],
dtype=np.float32)
exp_cls_targets = [[[0, 1], [0, 1], [1, 0]]]
exp_cls_weights = [[[1, 1], [1, 1], [1, 1]]]
exp_reg_targets = [[[0.25, 0.25, 0.5, 0.5],
[0.7, 0.7, 0.4, 0.4],
[0, 0, 0, 0]]]
exp_reg_weights = [[1, 1, 0]]
(cls_targets_out,
cls_weights_out, reg_targets_out, reg_weights_out) = self.execute_cpu(
graph_fn, [pred_corners, groundtruth_box_corners,
groundtruth_labels, predicted_labels])
self.assertAllClose(cls_targets_out, exp_cls_targets)
self.assertAllClose(cls_weights_out, exp_cls_weights)
self.assertAllClose(reg_targets_out, exp_reg_targets)
self.assertAllClose(reg_weights_out, exp_reg_weights)
self.assertEqual(cls_targets_out.dtype, np.float32)
self.assertEqual(cls_weights_out.dtype, np.float32)
self.assertEqual(reg_targets_out.dtype, np.float32)
self.assertEqual(reg_weights_out.dtype, np.float32)
if __name__ == '__main__':
tf.enable_v2_behavior()
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment