Commit a6f36d27 authored by Kaushik Shivakumar's avatar Kaushik Shivakumar
Browse files

work on cleaning up further

parent 323ea897
...@@ -56,8 +56,7 @@ class RegionSimilarityCalculator(six.with_metaclass(ABCMeta, object)): ...@@ -56,8 +56,7 @@ class RegionSimilarityCalculator(six.with_metaclass(ABCMeta, object)):
return self._compare(boxlist1, boxlist2) return self._compare(boxlist1, boxlist2)
@abstractmethod @abstractmethod
def _compare(self, boxlist1, boxlist2, def _compare(self, boxlist1, boxlist2):
groundtruth_labels=None, predicted_labels=None):
pass pass
...@@ -107,6 +106,7 @@ class DETRSimilarity(RegionSimilarityCalculator): ...@@ -107,6 +106,7 @@ class DETRSimilarity(RegionSimilarityCalculator):
boxlist1, boxlist2) + self.giou_weight * box_list_ops.giou( boxlist1, boxlist2) + self.giou_weight * box_list_ops.giou(
boxlist1, boxlist2) + classification_scores boxlist1, boxlist2) + classification_scores
class NegSqDistSimilarity(RegionSimilarityCalculator): class NegSqDistSimilarity(RegionSimilarityCalculator):
"""Class to compute similarity based on the squared distance metric. """Class to compute similarity based on the squared distance metric.
...@@ -126,6 +126,7 @@ class NegSqDistSimilarity(RegionSimilarityCalculator): ...@@ -126,6 +126,7 @@ class NegSqDistSimilarity(RegionSimilarityCalculator):
""" """
return -1 * box_list_ops.sq_dist(boxlist1, boxlist2) return -1 * box_list_ops.sq_dist(boxlist1, boxlist2)
class IoaSimilarity(RegionSimilarityCalculator): class IoaSimilarity(RegionSimilarityCalculator):
"""Class to compute similarity based on Intersection over Area (IOA) metric. """Class to compute similarity based on Intersection over Area (IOA) metric.
......
...@@ -1903,7 +1903,7 @@ class CenterNetCornerOffsetTargetAssigner(object): ...@@ -1903,7 +1903,7 @@ class CenterNetCornerOffsetTargetAssigner(object):
class DETRTargetAssigner(object): class DETRTargetAssigner(object):
"""Target assigner to compute classification and regression targets.""" """Target assigner to compute classification and regression targets."""
def __init__(self, negative_class_weight=1.0): def __init__(self):
"""Construct Object Detection Target Assigner. """Construct Object Detection Target Assigner.
Args: Args:
...@@ -1911,13 +1911,10 @@ class DETRTargetAssigner(object): ...@@ -1911,13 +1911,10 @@ class DETRTargetAssigner(object):
predicted boxes. predicted boxes.
box_coder_instance: an object_detection.core.BoxCoder used to encode box_coder_instance: an object_detection.core.BoxCoder used to encode
matching groundtruth boxes with respect to predicted boxes. matching groundtruth boxes with respect to predicted boxes.
negative_class_weight: classification weight to be associated to negative
boxes (default: 1.0). The weight must be in [0., 1.].
""" """
self._similarity_calc = sim_calc.DETRSimilarity() self._similarity_calc = sim_calc.DETRSimilarity()
self._matcher = hungarian_matcher.HungarianBipartiteMatcher() self._matcher = hungarian_matcher.HungarianBipartiteMatcher()
self._negative_class_weight = negative_class_weight
def batch_assign(self, def batch_assign(self,
pred_boxes_batch, pred_boxes_batch,
...@@ -2055,7 +2052,7 @@ class DETRTargetAssigner(object): ...@@ -2055,7 +2052,7 @@ class DETRTargetAssigner(object):
cls_weights = match.gather_based_on_match( cls_weights = match.gather_based_on_match(
groundtruth_weights, groundtruth_weights,
ignored_value=0., ignored_value=0.,
unmatched_value=self._negative_class_weight) unmatched_value=1)
# convert cls_weights from per-box_pred to per-class. # convert cls_weights from per-box_pred to per-class.
class_label_shape = tf.shape(cls_targets)[1:] class_label_shape = tf.shape(cls_targets)[1:]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment