Commit 4022aae5 authored by Kaushik Shivakumar's avatar Kaushik Shivakumar
Browse files

self code-review to clean up

parent f0bc684c
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
......@@ -225,7 +225,7 @@ class WeightedGIOULocalizationLoss(Loss):
Args:
prediction_tensor: A float tensor of shape [batch_size, num_anchors, 4]
representing the decoded predicted boxes
representing the predicted boxes in the form
target_tensor: A float tensor of shape [batch_size, num_anchors, 4]
representing the decoded target boxes
weights: a float tensor of shape [batch_size, num_anchors]
......@@ -236,10 +236,8 @@ class WeightedGIOULocalizationLoss(Loss):
"""
batch_size, num_anchors, _ = shape_utils.combined_static_and_dynamic_shape(
prediction_tensor)
predicted_boxes = ops.cy_cx_h_w_to_ymin_xmin_ymax_xmax_coords(
tf.reshape(prediction_tensor, [-1, 4]))
target_boxes = ops.cy_cx_h_w_to_ymin_xmin_ymax_xmax_coords(
tf.reshape(target_tensor, [-1, 4]))
predicted_boxes = tf.reshape(prediction_tensor, [-1, 4])
target_boxes = tf.reshape(target_tensor, [-1, 4])
per_anchor_iou_loss = 1 - ops.giou(predicted_boxes, target_boxes)
return tf.reshape(tf.reshape(weights, [-1]) * per_anchor_iou_loss,
......
......@@ -28,6 +28,7 @@ from object_detection.core import box_list
from object_detection.core import losses
from object_detection.core import matcher
from object_detection.utils import test_case
from object_detection.utils import ops
class WeightedL2LocalizationLossTest(test_case.TestCase):
......@@ -206,10 +207,12 @@ class WeightedGIOULocalizationLossTest(test_case.TestCase):
[0, 0, 0, 0]]])
target_tensor = tf.constant([[[1.5, 0, 2.4, 1],
[0, 0, 1, 1],
[7.5, 7.5, 2.5, 2.5]]])
[5, 5, 10, 10]]])
weights = [[1.0, .5, 2.0]]
loss_op = losses.WeightedGIOULocalizationLoss()
loss = loss_op(prediction_tensor, target_tensor, weights=weights)
loss = loss_op(prediction_tensor,
target_tensor,
weights=weights)
loss = tf.reduce_sum(loss)
return loss
exp_loss = 3.5
......
......@@ -34,7 +34,8 @@ from object_detection.core import standard_fields as fields
class RegionSimilarityCalculator(six.with_metaclass(ABCMeta, object)):
"""Abstract base class for region similarity calculator."""
def compare(self, boxlist1, boxlist2, scope=None, groundtruth_labels=None, predicted_labels=None):
def compare(self, boxlist1, boxlist2, scope=None,
groundtruth_labels=None, predicted_labels=None):
"""Computes matrix of pairwise similarity between BoxLists.
This op (to be overridden) computes a measure of pairwise similarity between
......@@ -59,7 +60,8 @@ class RegionSimilarityCalculator(six.with_metaclass(ABCMeta, object)):
return self._compare(boxlist1, boxlist2, groundtruth_labels, predicted_labels)
@abstractmethod
def _compare(self, boxlist1, boxlist2, groundtruth_labels=None, predicted_labels=None):
def _compare(self, boxlist1, boxlist2,
groundtruth_labels=None, predicted_labels=None):
pass
......
......@@ -102,11 +102,12 @@ class RegionSimilarityCalculatorTest(test_case.TestCase):
boxes1 = box_list.BoxList(corners1)
boxes2 = box_list.BoxList(corners2)
detr_similarity_calculator = region_similarity_calculator.DETRSimilarity()
detr_similarity = detr_similarity_calculator.compare(boxes1, boxes2, None, groundtruth_labels, predicted_labels)
detr_similarity = detr_similarity_calculator.compare(
boxes1, boxes2, None, groundtruth_labels, predicted_labels)
return detr_similarity
exp_output = [[2.0, -2.0/3.0 + 1.0 - 20.0]]
sim_output = self.execute(graph_fn, [])
self.assertAllClose(sim_output, exp_output)
if __name__ == '__main__':
tf.test.main()
\ No newline at end of file
tf.test.main()
......@@ -142,8 +142,8 @@ class TargetAssigner(object):
aware of groundtruth weights. Additionally, `cls_weights` and
`reg_weights` are calculated using groundtruth weights as an added
safety.
class_predictions: a float tensor of shape [N, num_classes] containing class
predictions from the model, to be used in certain matchers.
class_predictions: A tensor with shape [max_num_boxes, d_1, d_2, ..., d_k]
to be used by certain similarity calculators.
Returns:
cls_targets: a float32 tensor with shape [num_anchors, d_1, d_2 ... d_k],
......@@ -200,10 +200,11 @@ class TargetAssigner(object):
with tf.control_dependencies(
[unmatched_shape_assert, labels_and_box_shapes_assert]):
match_quality_matrix = self._similarity_calc.compare(groundtruth_boxes,
anchors,
groundtruth_labels=groundtruth_labels,
predicted_labels=class_predictions)
match_quality_matrix = self._similarity_calc.compare(
groundtruth_boxes,
anchors,
groundtruth_labels=groundtruth_labels,
predicted_labels=class_predictions)
match = self._matcher.match(match_quality_matrix,
valid_rows=tf.greater(groundtruth_weights, 0))
......@@ -481,7 +482,8 @@ def batch_assign(target_assigner,
function (which have shape [num_gt_boxes, d_1, d_2, ..., d_k]).
gt_weights_batch: A list of 1-D tf.float32 tensors of shape
[num_boxes] containing weights for groundtruth boxes.
class_predictions: A
class_predictions: A tensor with shape [batch_size, max_num_boxes,
d_1, d_2, ..., d_k] to be used by certain similarity calculators.
Returns:
batch_cls_targets: a tensor with shape [batch_size, num_anchors,
......@@ -531,8 +533,8 @@ def batch_assign(target_assigner,
class_predictions):
(cls_targets, cls_weights,
reg_targets, reg_weights, match) = target_assigner.assign(
anchors, gt_boxes, gt_class_targets, unmatched_class_label, gt_weights,
class_preds)
anchors, gt_boxes, gt_class_targets, unmatched_class_label,
gt_weights, class_preds)
cls_targets_list.append(cls_targets)
cls_weights_list.append(cls_weights)
reg_targets_list.append(reg_targets)
......
......@@ -1185,7 +1185,7 @@ def giou(boxes1, boxes2):
return shape_utils.static_or_dynamic_map_fn(two_boxes_giou, [boxes1, boxes2])
def cy_cx_h_w_to_ymin_xmin_ymax_xmax_coords(input_tensor):
def cy_cx_h_w_to_ymin_xmin_ymax_xmax(input_tensor):
"""Converts input boxes from center to corner representation."""
reshaped_encodings = tf.reshape(input_tensor, [-1, 4])
ycenter = tf.gather(reshaped_encodings, [0], axis=1)
......
......@@ -1664,7 +1664,7 @@ class TestCoordinateConversion(test_case.TestCase):
def graph_fn():
boxes = tf.constant([[3, 3, 5, 5], [3, 4, 2, 6]], dtype=tf.float32)
converted = ops.cy_cx_h_w_to_ymin_xmin_ymax_xmax_coords(boxes)
converted = ops.cy_cx_h_w_to_ymin_xmin_ymax_xmax(boxes)
self.assertEqual(converted.dtype, tf.float32)
return converted
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment