Commit d54c86de authored by Kaushik Shivakumar's avatar Kaushik Shivakumar
Browse files

make suggested fixes to target assigner and similarity calculator

parent 4f135c70
...@@ -35,8 +35,7 @@ from object_detection.core import standard_fields as fields ...@@ -35,8 +35,7 @@ from object_detection.core import standard_fields as fields
class RegionSimilarityCalculator(six.with_metaclass(ABCMeta, object)): class RegionSimilarityCalculator(six.with_metaclass(ABCMeta, object)):
"""Abstract base class for region similarity calculator.""" """Abstract base class for region similarity calculator."""
def compare(self, boxlist1, boxlist2, scope=None, def compare(self, boxlist1, boxlist2, scope=None):
groundtruth_labels=None, predicted_labels=None):
"""Computes matrix of pairwise similarity between BoxLists. """Computes matrix of pairwise similarity between BoxLists.
This op (to be overridden) computes a measure of pairwise similarity between This op (to be overridden) computes a measure of pairwise similarity between
...@@ -49,10 +48,6 @@ class RegionSimilarityCalculator(six.with_metaclass(ABCMeta, object)): ...@@ -49,10 +48,6 @@ class RegionSimilarityCalculator(six.with_metaclass(ABCMeta, object)):
boxlist1: BoxList holding N boxes. boxlist1: BoxList holding N boxes.
boxlist2: BoxList holding M boxes. boxlist2: BoxList holding M boxes.
scope: Op scope name. Defaults to 'Compare' if None. scope: Op scope name. Defaults to 'Compare' if None.
groundtruth_labels: a Tensor of shape [num_boxes, num_classes]
containing groundtruth labels.
predicted_labels: a Tensor of shape [num_boxes, num_classes]
containing predicted labels.
Returns: Returns:
a (float32) tensor of shape [N, M] with pairwise similarity score. a (float32) tensor of shape [N, M] with pairwise similarity score.
...@@ -72,17 +67,12 @@ class IouSimilarity(RegionSimilarityCalculator): ...@@ -72,17 +67,12 @@ class IouSimilarity(RegionSimilarityCalculator):
This class computes pairwise similarity between two BoxLists based on IOU. This class computes pairwise similarity between two BoxLists based on IOU.
""" """
def _compare(self, boxlist1, boxlist2, def _compare(self, boxlist1, boxlist2):
groundtruth_labels=None, predicted_labels=None):
"""Compute pairwise IOU similarity between the two BoxLists. """Compute pairwise IOU similarity between the two BoxLists.
Args: Args:
boxlist1: BoxList holding N boxes. boxlist1: BoxList holding N boxes.
boxlist2: BoxList holding M boxes. boxlist2: BoxList holding M boxes.
groundtruth_labels: a Tensor of shape [num_boxes, num_classes]
containing groundtruth labels.
predicted_labels: a Tensor of shape [num_boxes, num_classes]
containing predicted labels.
Returns: Returns:
A tensor with shape [N, M] representing pairwise iou scores. A tensor with shape [N, M] representing pairwise iou scores.
...@@ -95,25 +85,26 @@ class DETRSimilarity(RegionSimilarityCalculator): ...@@ -95,25 +85,26 @@ class DETRSimilarity(RegionSimilarityCalculator):
This class computes pairwise similarity between two BoxLists using a weighted This class computes pairwise similarity between two BoxLists using a weighted
combination of IOU, classification scores, and the L1 loss. combination of IOU, classification scores, and the L1 loss.
""" """
def __init__(self, l1_weight=5, giou_weight=2):
self.l1_weight = l1_weight
self.giou_weight = giou_weight
def _compare(self, boxlist1, boxlist2, def _compare(self, boxlist1, boxlist2):
groundtruth_labels=None, predicted_labels=None):
"""Compute pairwise IOU similarity between the two BoxLists. """Compute pairwise IOU similarity between the two BoxLists.
Args: Args:
boxlist1: BoxList holding N boxes. boxlist1: BoxList holding N groundtruth boxes.
boxlist2: BoxList holding M boxes. boxlist2: BoxList holding M predicted boxes.
groundtruth_labels: a Tensor of shape [num_boxes, num_classes]
containing groundtruth labels.
predicted_labels: a Tensor of shape [num_boxes, num_classes]
containing predicted labels.
Returns: Returns:
A tensor with shape [N, M] representing pairwise iou scores. A tensor with shape [N, M] representing pairwise iou scores.
""" """
groundtruth_labels = boxlist1.get_field(fields.BoxListFields.classes)
predicted_labels = boxlist2.get_field(fields.BoxListFields.classes)
classification_scores = tf.matmul(groundtruth_labels, classification_scores = tf.matmul(groundtruth_labels,
tf.nn.softmax(predicted_labels), transpose_b=True) tf.nn.softmax(predicted_labels), transpose_b=True)
return -5 * box_list_ops.l1(boxlist1, boxlist2) + 2 * box_list_ops.giou( return -self.l1_weight * box_list_ops.l1(
boxlist1, boxlist2) + self.giou_weight * box_list_ops.giou(
boxlist1, boxlist2) + classification_scores boxlist1, boxlist2) + classification_scores
class NegSqDistSimilarity(RegionSimilarityCalculator): class NegSqDistSimilarity(RegionSimilarityCalculator):
...@@ -123,17 +114,12 @@ class NegSqDistSimilarity(RegionSimilarityCalculator): ...@@ -123,17 +114,12 @@ class NegSqDistSimilarity(RegionSimilarityCalculator):
negative squared distance metric. negative squared distance metric.
""" """
def _compare(self, boxlist1, boxlist2, def _compare(self, boxlist1, boxlist2):
groundtruth_labels=None, predicted_labels=None):
"""Compute matrix of (negated) sq distances. """Compute matrix of (negated) sq distances.
Args: Args:
boxlist1: BoxList holding N boxes. boxlist1: BoxList holding N boxes.
boxlist2: BoxList holding M boxes. boxlist2: BoxList holding M boxes.
groundtruth_labels: a Tensor of shape [num_boxes, num_classes]
containing groundtruth labels.
predicted_labels: a Tensor of shape [num_boxes, num_classes]
containing predicted labels.
Returns: Returns:
A tensor with shape [N, M] representing negated pairwise squared distance. A tensor with shape [N, M] representing negated pairwise squared distance.
...@@ -147,17 +133,12 @@ class IoaSimilarity(RegionSimilarityCalculator): ...@@ -147,17 +133,12 @@ class IoaSimilarity(RegionSimilarityCalculator):
pairwise intersections divided by the areas of second BoxLists. pairwise intersections divided by the areas of second BoxLists.
""" """
def _compare(self, boxlist1, boxlist2, def _compare(self, boxlist1, boxlist2):
groundtruth_labels=None, predicted_labels=None):
"""Compute pairwise IOA similarity between the two BoxLists. """Compute pairwise IOA similarity between the two BoxLists.
Args: Args:
boxlist1: BoxList holding N boxes. boxlist1: BoxList holding N boxes.
boxlist2: BoxList holding M boxes. boxlist2: BoxList holding M boxes.
groundtruth_labels: a Tensor of shape [num_boxes, num_classes]
containing groundtruth labels.
predicted_labels: a Tensor of shape [num_boxes, num_classes]
containing predicted labels.
Returns: Returns:
A tensor with shape [N, M] representing pairwise IOA scores. A tensor with shape [N, M] representing pairwise IOA scores.
...@@ -184,17 +165,12 @@ class ThresholdedIouSimilarity(RegionSimilarityCalculator): ...@@ -184,17 +165,12 @@ class ThresholdedIouSimilarity(RegionSimilarityCalculator):
super(ThresholdedIouSimilarity, self).__init__() super(ThresholdedIouSimilarity, self).__init__()
self._iou_threshold = iou_threshold self._iou_threshold = iou_threshold
def _compare(self, boxlist1, boxlist2, def _compare(self, boxlist1, boxlist2):
groundtruth_labels=None, predicted_labels=None):
"""Compute pairwise IOU similarity between the two BoxLists and score. """Compute pairwise IOU similarity between the two BoxLists and score.
Args: Args:
boxlist1: BoxList holding N boxes. Must have a score field. boxlist1: BoxList holding N boxes. Must have a score field.
boxlist2: BoxList holding M boxes. boxlist2: BoxList holding M boxes.
groundtruth_labels: a Tensor of shape [num_boxes, num_classes]
containing groundtruth labels.
predicted_labels: a Tensor of shape [num_boxes, num_classes]
containing predicted labels.
Returns: Returns:
A tensor with shape [N, M] representing scores threholded by pairwise A tensor with shape [N, M] representing scores threholded by pairwise
......
...@@ -101,9 +101,11 @@ class RegionSimilarityCalculatorTest(test_case.TestCase): ...@@ -101,9 +101,11 @@ class RegionSimilarityCalculatorTest(test_case.TestCase):
predicted_labels = tf.constant([[0.0, 1000.0], [1000.0, 0.0]]) predicted_labels = tf.constant([[0.0, 1000.0], [1000.0, 0.0]])
boxes1 = box_list.BoxList(corners1) boxes1 = box_list.BoxList(corners1)
boxes2 = box_list.BoxList(corners2) boxes2 = box_list.BoxList(corners2)
boxes1.add_field(fields.BoxListFields.classes, groundtruth_labels)
boxes2.add_field(fields.BoxListFields.classes, predicted_labels)
detr_similarity_calculator = region_similarity_calculator.DETRSimilarity() detr_similarity_calculator = region_similarity_calculator.DETRSimilarity()
detr_similarity = detr_similarity_calculator.compare( detr_similarity = detr_similarity_calculator.compare(
boxes1, boxes2, None, groundtruth_labels, predicted_labels) boxes1, boxes2, None)
return detr_similarity return detr_similarity
exp_output = [[2.0, -2.0/3.0 + 1.0 - 20.0]] exp_output = [[2.0, -2.0/3.0 + 1.0 - 20.0]]
sim_output = self.execute(graph_fn, []) sim_output = self.execute(graph_fn, [])
......
...@@ -51,6 +51,7 @@ from object_detection.core import matcher as mat ...@@ -51,6 +51,7 @@ from object_detection.core import matcher as mat
from object_detection.core import region_similarity_calculator as sim_calc from object_detection.core import region_similarity_calculator as sim_calc
from object_detection.core import standard_fields as fields from object_detection.core import standard_fields as fields
from object_detection.matchers import argmax_matcher from object_detection.matchers import argmax_matcher
from object_detection.matchers import hungarian_matcher
from object_detection.utils import shape_utils from object_detection.utils import shape_utils
from object_detection.utils import target_assigner_utils as ta_utils from object_detection.utils import target_assigner_utils as ta_utils
from object_detection.utils import tf_version from object_detection.utils import tf_version
...@@ -1917,51 +1918,44 @@ class DETRTargetAssigner(object): ...@@ -1917,51 +1918,44 @@ class DETRTargetAssigner(object):
"""Target assigner to compute classification and regression targets.""" """Target assigner to compute classification and regression targets."""
def __init__(self, def __init__(self,
similarity_calc,
matcher, matcher,
negative_class_weight=1.0): negative_class_weight=1.0):
"""Construct Object Detection Target Assigner. """Construct Object Detection Target Assigner.
Args: Args:
similarity_calc: a RegionSimilarityCalculator
matcher: an object_detection.core.Matcher used to match groundtruth to matcher: an object_detection.core.Matcher used to match groundtruth to
anchors. predicted boxes.
box_coder_instance: an object_detection.core.BoxCoder used to encode box_coder_instance: an object_detection.core.BoxCoder used to encode
matching groundtruth boxes with respect to anchors. matching groundtruth boxes with respect to predicted boxes.
negative_class_weight: classification weight to be associated to negative negative_class_weight: classification weight to be associated to negative
anchors (default: 1.0). The weight must be in [0., 1.]. boxes (default: 1.0). The weight must be in [0., 1.].
Raises:
ValueError: if similarity_calc is not a RegionSimilarityCalculator or
if matcher is not a Matcher or if box_coder is not a BoxCoder
""" """
if not isinstance(similarity_calc, sim_calc.RegionSimilarityCalculator):
raise ValueError('similarity_calc must be a RegionSimilarityCalculator')
if not isinstance(matcher, mat.Matcher): if not isinstance(matcher, mat.Matcher):
raise ValueError('matcher must be a Matcher') raise ValueError('matcher must be a Matcher')
self._similarity_calc = similarity_calc self._similarity_calc = sim_calc.DETRSimilarity()
self._matcher = matcher self._matcher = hungarian_matcher.HungarianBipartiteMatcher()
self._negative_class_weight = negative_class_weight self._negative_class_weight = negative_class_weight
def assign(self, def assign(self,
anchors, box_preds,
groundtruth_boxes, groundtruth_boxes,
groundtruth_labels=None, groundtruth_labels=None,
unmatched_class_label=None, unmatched_class_label=None,
groundtruth_weights=None, groundtruth_weights=None,
class_predictions=None): class_predictions=None):
"""Assign classification and regression targets to each anchor. """Assign classification and regression targets to each box_pred.
For a given set of anchors and groundtruth detections, match anchors For a given set of box_preds and groundtruth detections, match box_preds
to groundtruth_boxes and assign classification and regression targets to to groundtruth_boxes and assign classification and regression targets to
each anchor as well as weights based on the resulting match (specifying, each box_pred as well as weights based on the resulting match (specifying,
e.g., which anchors should not contribute to training loss). e.g., which box_preds should not contribute to training loss).
Anchors that are not matched to anything are given a classification target box_preds that are not matched to anything are given a classification target
of self._unmatched_cls_target which can be specified via the constructor. of self._unmatched_cls_target which can be specified via the constructor.
Args: Args:
anchors: a BoxList representing N anchors box_preds: a BoxList representing N box_preds
groundtruth_boxes: a BoxList representing M groundtruth boxes groundtruth_boxes: a BoxList representing M groundtruth boxes
groundtruth_labels: a tensor of shape [M, d_1, ... d_k] groundtruth_labels: a tensor of shape [M, d_1, ... d_k]
with labels for each of the ground_truth boxes. The subshape with labels for each of the ground_truth boxes. The subshape
...@@ -1970,14 +1964,14 @@ class DETRTargetAssigner(object): ...@@ -1970,14 +1964,14 @@ class DETRTargetAssigner(object):
ground_truth boxes get a positive label (of 1). ground_truth boxes get a positive label (of 1).
unmatched_class_label: a float32 tensor with shape [d_1, d_2, ..., d_k] unmatched_class_label: a float32 tensor with shape [d_1, d_2, ..., d_k]
which is consistent with the classification target for each which is consistent with the classification target for each
anchor (and can be empty for scalar targets). This shape must thus be box_pred (and can be empty for scalar targets). This shape must thus be
compatible with the groundtruth labels that are passed to the "assign" compatible with the groundtruth labels that are passed to the "assign"
function (which have shape [num_gt_boxes, d_1, d_2, ..., d_k]). function (which have shape [num_gt_boxes, d_1, d_2, ..., d_k]).
If set to None, unmatched_cls_target is set to be [0] for each anchor. If set to None, unmatched_cls_target is set to be [0] for each box_pred.
groundtruth_weights: a float tensor of shape [M] indicating the weight to groundtruth_weights: a float tensor of shape [M] indicating the weight to
assign to all anchors match to a particular groundtruth box. The weights assign to all box_preds match to a particular groundtruth box. The weights
must be in [0., 1.]. If None, all weights are set to 1. Generally no must be in [0., 1.]. If None, all weights are set to 1. Generally no
groundtruth boxes with zero weight match to any anchors as matchers are groundtruth boxes with zero weight match to any box_preds as matchers are
aware of groundtruth weights. Additionally, `cls_weights` and aware of groundtruth weights. Additionally, `cls_weights` and
`reg_weights` are calculated using groundtruth weights as an added `reg_weights` are calculated using groundtruth weights as an added
safety. safety.
...@@ -1985,27 +1979,27 @@ class DETRTargetAssigner(object): ...@@ -1985,27 +1979,27 @@ class DETRTargetAssigner(object):
to be used by certain similarity calculators. to be used by certain similarity calculators.
Returns: Returns:
cls_targets: a float32 tensor with shape [num_anchors, d_1, d_2 ... d_k], cls_targets: a float32 tensor with shape [num_box_preds, d_1, d_2 ... d_k],
where the subshape [d_1, ..., d_k] is compatible with groundtruth_labels where the subshape [d_1, ..., d_k] is compatible with groundtruth_labels
which has shape [num_gt_boxes, d_1, d_2, ... d_k]. which has shape [num_gt_boxes, d_1, d_2, ... d_k].
cls_weights: a float32 tensor with shape [num_anchors, d_1, d_2 ... d_k], cls_weights: a float32 tensor with shape [num_box_preds, d_1, d_2 ... d_k],
representing weights for each element in cls_targets. representing weights for each element in cls_targets.
reg_targets: a float32 tensor with shape [num_anchors, box_code_dimension] reg_targets: a float32 tensor with shape [num_box_preds, box_code_dimension]
reg_weights: a float32 tensor with shape [num_anchors] reg_weights: a float32 tensor with shape [num_box_preds]
match: an int32 tensor of shape [num_anchors] containing result of anchor match: an int32 tensor of shape [num_box_preds] containing result of box_pred
groundtruth matching. Each position in the tensor indicates an anchor groundtruth matching. Each position in the tensor indicates an box_pred
and holds the following meaning: and holds the following meaning:
(1) if match[i] >= 0, anchor i is matched with groundtruth match[i]. (1) if match[i] >= 0, box_pred i is matched with groundtruth match[i].
(2) if match[i]=-1, anchor i is marked to be background . (2) if match[i]=-1, box_pred i is marked to be background .
(3) if match[i]=-2, anchor i is ignored since it is not background and (3) if match[i]=-2, box_pred i is ignored since it is not background and
does not have sufficient overlap to call it a foreground. does not have sufficient overlap to call it a foreground.
Raises: Raises:
ValueError: if anchors or groundtruth_boxes are not of type ValueError: if box_preds or groundtruth_boxes are not of type
box_list.BoxList box_list.BoxList
""" """
if not isinstance(anchors, box_list.BoxList): if not isinstance(box_preds, box_list.BoxList):
raise ValueError('anchors must be an BoxList') raise ValueError('box_preds must be an BoxList')
if not isinstance(groundtruth_boxes, box_list.BoxList): if not isinstance(groundtruth_boxes, box_list.BoxList):
raise ValueError('groundtruth_boxes must be an BoxList') raise ValueError('groundtruth_boxes must be an BoxList')
...@@ -2017,15 +2011,6 @@ class DETRTargetAssigner(object): ...@@ -2017,15 +2011,6 @@ class DETRTargetAssigner(object):
0)) 0))
groundtruth_labels = tf.expand_dims(groundtruth_labels, -1) groundtruth_labels = tf.expand_dims(groundtruth_labels, -1)
unmatched_shape_assert = shape_utils.assert_shape_equal(
shape_utils.combined_static_and_dynamic_shape(groundtruth_labels)[1:],
shape_utils.combined_static_and_dynamic_shape(unmatched_class_label))
labels_and_box_shapes_assert = shape_utils.assert_shape_equal(
shape_utils.combined_static_and_dynamic_shape(
groundtruth_labels)[:1],
shape_utils.combined_static_and_dynamic_shape(
groundtruth_boxes.get())[:1])
if groundtruth_weights is None: if groundtruth_weights is None:
num_gt_boxes = groundtruth_boxes.num_boxes_static() num_gt_boxes = groundtruth_boxes.num_boxes_static()
if not num_gt_boxes: if not num_gt_boxes:
...@@ -2036,18 +2021,19 @@ class DETRTargetAssigner(object): ...@@ -2036,18 +2021,19 @@ class DETRTargetAssigner(object):
scores = 1 - groundtruth_labels[:, 0] scores = 1 - groundtruth_labels[:, 0]
groundtruth_boxes.add_field(fields.BoxListFields.scores, scores) groundtruth_boxes.add_field(fields.BoxListFields.scores, scores)
groundtruth_boxes.add_field(fields.BoxListFields.classes, groundtruth_labels)
box_preds.add_field(fields.BoxListFields.classes, class_predictions)
with tf.control_dependencies( with tf.control_dependencies(
[unmatched_shape_assert, labels_and_box_shapes_assert]): [unmatched_shape_assert, labels_and_box_shapes_assert]):
match_quality_matrix = self._similarity_calc.compare( match_quality_matrix = self._similarity_calc.compare(
groundtruth_boxes, groundtruth_boxes,
anchors, box_preds)
groundtruth_labels=groundtruth_labels,
predicted_labels=class_predictions)
match = self._matcher.match(match_quality_matrix, match = self._matcher.match(match_quality_matrix,
valid_rows=tf.greater(groundtruth_weights, 0)) valid_rows=tf.greater(groundtruth_weights, 0))
reg_targets = self._create_regression_targets(anchors, reg_targets = self._create_regression_targets(box_preds,
groundtruth_boxes, groundtruth_boxes,
match) match)
cls_targets = match.gather_based_on_match( cls_targets = match.gather_based_on_match(
...@@ -2062,7 +2048,7 @@ class DETRTargetAssigner(object): ...@@ -2062,7 +2048,7 @@ class DETRTargetAssigner(object):
ignored_value=0., ignored_value=0.,
unmatched_value=self._negative_class_weight) unmatched_value=self._negative_class_weight)
# convert cls_weights from per-anchor to per-class. # convert cls_weights from per-box_pred to per-class.
class_label_shape = tf.shape(cls_targets)[1:] class_label_shape = tf.shape(cls_targets)[1:]
weights_shape = tf.shape(cls_weights) weights_shape = tf.shape(cls_weights)
weights_multiple = tf.concat( weights_multiple = tf.concat(
...@@ -2072,37 +2058,37 @@ class DETRTargetAssigner(object): ...@@ -2072,37 +2058,37 @@ class DETRTargetAssigner(object):
cls_weights = tf.expand_dims(cls_weights, -1) cls_weights = tf.expand_dims(cls_weights, -1)
cls_weights = tf.tile(cls_weights, weights_multiple) cls_weights = tf.tile(cls_weights, weights_multiple)
num_anchors = anchors.num_boxes_static() num_box_preds = box_preds.num_boxes_static()
if num_anchors is not None: if num_box_preds is not None:
reg_targets = self._reset_target_shape(reg_targets, num_anchors) reg_targets = self._reset_target_shape(reg_targets, num_box_preds)
cls_targets = self._reset_target_shape(cls_targets, num_anchors) cls_targets = self._reset_target_shape(cls_targets, num_box_preds)
reg_weights = self._reset_target_shape(reg_weights, num_anchors) reg_weights = self._reset_target_shape(reg_weights, num_box_preds)
cls_weights = self._reset_target_shape(cls_weights, num_anchors) cls_weights = self._reset_target_shape(cls_weights, num_box_preds)
return (cls_targets, cls_weights, reg_targets, reg_weights, return (cls_targets, cls_weights, reg_targets, reg_weights,
match.match_results) match.match_results)
def _reset_target_shape(self, target, num_anchors): def _reset_target_shape(self, target, num_box_preds):
"""Sets the static shape of the target. """Sets the static shape of the target.
Args: Args:
target: the target tensor. Its first dimension will be overwritten. target: the target tensor. Its first dimension will be overwritten.
num_anchors: the number of anchors, which is used to override the target's num_box_preds: the number of box_preds, which is used to override the target's
first dimension. first dimension.
Returns: Returns:
A tensor with the shape info filled in. A tensor with the shape info filled in.
""" """
target_shape = target.get_shape().as_list() target_shape = target.get_shape().as_list()
target_shape[0] = num_anchors target_shape[0] = num_box_preds
target.set_shape(target_shape) target.set_shape(target_shape)
return target return target
def _create_regression_targets(self, anchors, groundtruth_boxes, match): def _create_regression_targets(self, box_preds, groundtruth_boxes, match):
"""Returns a regression target for each anchor. """Returns a regression target for each box_pred.
Args: Args:
anchors: a BoxList representing N anchors box_preds: a BoxList representing N box_preds
groundtruth_boxes: a BoxList representing M groundtruth_boxes groundtruth_boxes: a BoxList representing M groundtruth_boxes
match: a matcher.Match object match: a matcher.Match object
...@@ -2123,8 +2109,8 @@ class DETRTargetAssigner(object): ...@@ -2123,8 +2109,8 @@ class DETRTargetAssigner(object):
# Zero out the unmatched and ignored regression targets. # Zero out the unmatched and ignored regression targets.
unmatched_ignored_reg_targets = tf.tile( unmatched_ignored_reg_targets = tf.tile(
tf.constant([4 * [0]], tf.float32), [match_results_shape[0], 1]) tf.constant([4 * [0]], tf.float32), [match_results_shape[0], 1])
matched_anchors_mask = match.matched_column_indicator() matched_box_preds_mask = match.matched_column_indicator()
reg_targets = tf.where(matched_anchors_mask, reg_targets = tf.where(matched_box_preds_mask,
matched_reg_targets, matched_reg_targets,
unmatched_ignored_reg_targets) unmatched_ignored_reg_targets)
return reg_targets return reg_targets
...@@ -2204,11 +2204,11 @@ class DETRTargetAssignerTest(testcase.TestCase): ...@@ -2204,11 +2204,11 @@ class DETRTargetAssignerTest(testcase.TestCase):
similarity_calc = region_similarity_calculator.DETRSimilarity() similarity_calc = region_similarity_calculator.DETRSimilarity()
matcher = hungarian_matcher.HungarianBipartiteMatcher() matcher = hungarian_matcher.HungarianBipartiteMatcher()
box_coder = detr_box_coder.DETRBoxCoder() box_coder = detr_box_coder.DETRBoxCoder()
target_assigner = targetassigner.TargetAssigner( detr_target_assigner = target_assigner.DETRTargetAssigner(
similarity_calc, matcher, box_coder) similarity_calc, matcher, box_coder)
anchors_boxlist = box_list.BoxList(anchor_means) anchors_boxlist = box_list.BoxList(anchor_means)
groundtruth_boxlist = box_list.BoxList(groundtruth_box_corners) groundtruth_boxlist = box_list.BoxList(groundtruth_box_corners)
result = target_assigner.assign( result = detr_target_assigner.assign(
anchors_boxlist, groundtruth_boxlist, anchors_boxlist, groundtruth_boxlist,
unmatched_class_label=tf.constant( unmatched_class_label=tf.constant(
[1, 0], dtype=tf.float32), [1, 0], dtype=tf.float32),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment