Commit a6f36d27 authored by Kaushik Shivakumar's avatar Kaushik Shivakumar
Browse files

work on cleaning up further

parent 323ea897
......@@ -56,8 +56,7 @@ class RegionSimilarityCalculator(six.with_metaclass(ABCMeta, object)):
return self._compare(boxlist1, boxlist2)
@abstractmethod
def _compare(self, boxlist1, boxlist2,
groundtruth_labels=None, predicted_labels=None):
def _compare(self, boxlist1, boxlist2):
pass
......@@ -107,6 +106,7 @@ class DETRSimilarity(RegionSimilarityCalculator):
boxlist1, boxlist2) + self.giou_weight * box_list_ops.giou(
boxlist1, boxlist2) + classification_scores
class NegSqDistSimilarity(RegionSimilarityCalculator):
"""Class to compute similarity based on the squared distance metric.
......@@ -126,6 +126,7 @@ class NegSqDistSimilarity(RegionSimilarityCalculator):
"""
return -1 * box_list_ops.sq_dist(boxlist1, boxlist2)
class IoaSimilarity(RegionSimilarityCalculator):
"""Class to compute similarity based on Intersection over Area (IOA) metric.
......
......@@ -1903,7 +1903,7 @@ class CenterNetCornerOffsetTargetAssigner(object):
class DETRTargetAssigner(object):
"""Target assigner to compute classification and regression targets."""
def __init__(self, negative_class_weight=1.0):
def __init__(self):
"""Construct Object Detection Target Assigner.
Args:
......@@ -1911,13 +1911,10 @@ class DETRTargetAssigner(object):
predicted boxes.
box_coder_instance: an object_detection.core.BoxCoder used to encode
matching groundtruth boxes with respect to predicted boxes.
negative_class_weight: classification weight to be associated to negative
boxes (default: 1.0). The weight must be in [0., 1.].
"""
self._similarity_calc = sim_calc.DETRSimilarity()
self._matcher = hungarian_matcher.HungarianBipartiteMatcher()
self._negative_class_weight = negative_class_weight
def batch_assign(self,
pred_boxes_batch,
......@@ -2055,7 +2052,7 @@ class DETRTargetAssigner(object):
cls_weights = match.gather_based_on_match(
groundtruth_weights,
ignored_value=0.,
unmatched_value=self._negative_class_weight)
unmatched_value=1)
# convert cls_weights from per-box_pred to per-class.
class_label_shape = tf.shape(cls_targets)[1:]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment