"...resnet50_tensorflow.git" did not exist on "f428c400b875f415f26f5c93bf565ba3cb4057f1"
Commit c95500cc authored by Kaushik Shivakumar's avatar Kaushik Shivakumar
Browse files

updates

parent ceb406bb
...@@ -339,16 +339,12 @@ def giou_loss(boxlist1, boxlist2, scope=None): ...@@ -339,16 +339,12 @@ def giou_loss(boxlist1, boxlist2, scope=None):
Returns: Returns:
a tensor with shape [N, M] representing the pairwise GIoU loss. a tensor with shape [N, M] representing the pairwise GIoU loss.
""" """
# Import done internally so this dependency is not required for
# the OD API as a whole.
import tensorflow_addons as tfa
with tf.name_scope(scope, "PairwiseGIoU"): with tf.name_scope(scope, "PairwiseGIoU"):
N = boxlist1.num_boxes() N = boxlist1.num_boxes()
M = boxlist2.num_boxes() M = boxlist2.num_boxes()
boxes1 = tf.repeat(boxlist1.get(), repeats=M, axis=0) boxes1 = tf.repeat(boxlist1.get(), repeats=M, axis=0)
boxes2 = tf.tile(boxlist2.get(), multiples=[N, 1]) boxes2 = tf.tile(boxlist2.get(), multiples=[N, 1])
return tf.reshape(tfa.losses.giou_loss(boxes1, boxes2), [N, M]) return tf.reshape(1.0 - ops.giou(boxes1, boxes2), [N, M])
def matched_iou(boxlist1, boxlist2, scope=None): def matched_iou(boxlist1, boxlist2, scope=None):
"""Compute intersection-over-union between corresponding boxes in boxlists. """Compute intersection-over-union between corresponding boxes in boxlists.
...@@ -365,16 +361,11 @@ def matched_iou(boxlist1, boxlist2, scope=None): ...@@ -365,16 +361,11 @@ def matched_iou(boxlist1, boxlist2, scope=None):
intersections = matched_intersection(boxlist1, boxlist2) intersections = matched_intersection(boxlist1, boxlist2)
areas1 = area(boxlist1) areas1 = area(boxlist1)
areas2 = area(boxlist2) areas2 = area(boxlist2)
print("AREAS AND INTERSECTION", areas1, areas2, intersections)
unions = areas1 + areas2 - intersections unions = areas1 + areas2 - intersections
return tf.where( return tf.where(
tf.equal(intersections, 0.0), tf.equal(intersections, 0.0),
tf.zeros_like(intersections), tf.truediv(intersections, unions)) tf.zeros_like(intersections), tf.truediv(intersections, unions))
def matched_giou(boxlist1, boxlist2, scope=None):
with tf.name_scope(scope, 'MatchedGIOU'):
pass
def ioa(boxlist1, boxlist2, scope=None): def ioa(boxlist1, boxlist2, scope=None):
"""Computes pairwise intersection-over-area between box collections. """Computes pairwise intersection-over-area between box collections.
......
...@@ -37,7 +37,7 @@ from object_detection.core import box_list ...@@ -37,7 +37,7 @@ from object_detection.core import box_list
from object_detection.core import box_list_ops from object_detection.core import box_list_ops
from object_detection.utils import ops from object_detection.utils import ops
from object_detection.utils import shape_utils from object_detection.utils import shape_utils
import tensorflow_addons as tfa
class Loss(six.with_metaclass(abc.ABCMeta, object)): class Loss(six.with_metaclass(abc.ABCMeta, object)):
"""Abstract base class for loss functions.""" """Abstract base class for loss functions."""
...@@ -181,6 +181,7 @@ class WeightedSmoothL1LocalizationLoss(Loss): ...@@ -181,6 +181,7 @@ class WeightedSmoothL1LocalizationLoss(Loss):
reduction=tf.losses.Reduction.NONE reduction=tf.losses.Reduction.NONE
), axis=2) ), axis=2)
class WeightedIOULocalizationLoss(Loss): class WeightedIOULocalizationLoss(Loss):
"""IOU localization loss function. """IOU localization loss function.
...@@ -209,6 +210,8 @@ class WeightedIOULocalizationLoss(Loss): ...@@ -209,6 +210,8 @@ class WeightedIOULocalizationLoss(Loss):
target_boxes) target_boxes)
return tf.reshape(weights, [-1]) * per_anchor_iou_loss return tf.reshape(weights, [-1]) * per_anchor_iou_loss
class WeightedGIOULocalizationLoss(Loss): class WeightedGIOULocalizationLoss(Loss):
"""IOU localization loss function. """IOU localization loss function.
......
...@@ -31,8 +31,6 @@ import tensorflow.compat.v1 as tf ...@@ -31,8 +31,6 @@ import tensorflow.compat.v1 as tf
from object_detection.core import box_list_ops from object_detection.core import box_list_ops
from object_detection.core import standard_fields as fields from object_detection.core import standard_fields as fields
EPSILON = 1e-8
class RegionSimilarityCalculator(six.with_metaclass(ABCMeta, object)): class RegionSimilarityCalculator(six.with_metaclass(ABCMeta, object)):
"""Abstract base class for region similarity calculator.""" """Abstract base class for region similarity calculator."""
...@@ -109,8 +107,8 @@ class DETRSimiliarity(RegionSimilarityCalculator): ...@@ -109,8 +107,8 @@ class DETRSimiliarity(RegionSimilarityCalculator):
classification_scores = tf.matmul(groundtruth_labels, classification_scores = tf.matmul(groundtruth_labels,
tf.nn.softmax(predicted_labels), transpose_b=True) tf.nn.softmax(predicted_labels), transpose_b=True)
return -5 * box_list_ops.l1(boxlist1, boxlist2) + \ return -5 * box_list_ops.l1(boxlist1, boxlist2) + \
classification_scores + \ 2 * (1 - box_list_ops.giou_loss(boxlist1, boxlist2)) + \
2 * (1 - box_list_ops.giou_loss(boxlist1, boxlist2)) classification_scores
class NegSqDistSimilarity(RegionSimilarityCalculator): class NegSqDistSimilarity(RegionSimilarityCalculator):
"""Class to compute similarity based on the squared distance metric. """Class to compute similarity based on the squared distance metric.
......
...@@ -42,7 +42,6 @@ import tensorflow.compat.v2 as tf2 ...@@ -42,7 +42,6 @@ import tensorflow.compat.v2 as tf2
from object_detection.box_coders import faster_rcnn_box_coder from object_detection.box_coders import faster_rcnn_box_coder
from object_detection.box_coders import mean_stddev_box_coder from object_detection.box_coders import mean_stddev_box_coder
from object_detection.box_coders import detr_box_coder
from object_detection.core import box_coder from object_detection.core import box_coder
from object_detection.core import box_list from object_detection.core import box_list
from object_detection.core import box_list_ops from object_detection.core import box_list_ops
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment