Commit 980d176a authored by Kaushik Shivakumar's avatar Kaushik Shivakumar
Browse files

fix

parent 7723b206
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Faster RCNN box coder.
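The Faster RCNN box coder follows the coding schema described below:
ty = (y - ya) / ha
tx = (x - xa) / wa
th = log(h / ha)
tw = log(w / wa)
where x, y, w, h denote the box's center coordinates, width and height
respectively. Similarly, xa, ya, wa, ha denote the anchor's center
coordinates, width and height. tx, ty, tw and th denote the anchor-encoded
center, width and height respectively.
See http://arxiv.org/abs/1506.01497 for details.
The DETRBoxCoder defined in this module does not apply that schema. It ignores
the anchors and encodes each box directly as
ty = y
tx = x
th = h + EPSILON
tw = w + EPSILON
and decodes a code back to corners as
ymin = ty - th / 2, xmin = tx - tw / 2, ymax = ty + th / 2, xmax = tx + tw / 2.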
"""
import tensorflow.compat.v1 as tf
from object_detection.core import box_coder
from object_detection.core import box_list
# Small positive offset that DETRBoxCoder._encode adds to box heights and
# widths before emitting them as codes.
EPSILON = 1e-8
class DETRBoxCoder(box_coder.BoxCoder):
  """DETR box coder.

  Encodes boxes as [ycenter, xcenter, height, width] and decodes such codes
  back into [ymin, xmin, ymax, xmax] corner boxes. The `anchors` argument of
  `_encode` and `_decode` is accepted for interface compatibility with
  `box_coder.BoxCoder` but is not read.
  """

  def __init__(self, scale_factors=None):
    """Constructor for DETRBoxCoder.

    Args:
      scale_factors: Optional list of 4 positive scalars. The value is
        validated and stored on the instance, but `_encode` and `_decode` do
        not apply it. If None, no validation is performed.

    Raises:
      ValueError: If `scale_factors` is not None and does not consist of
        exactly 4 positive scalars.
    """
    # Validate eagerly so a malformed value fails at construction time.
    # Explicit raises (rather than asserts) keep the checks active under -O.
    if scale_factors is not None:
      if len(scale_factors) != 4:
        raise ValueError(
            'scale_factors must contain exactly 4 values, got %d.' %
            len(scale_factors))
      for scalar in scale_factors:
        if not scalar > 0:
          raise ValueError(
              'scale_factors must all be positive, got %r.' % (scalar,))
    self._scale_factors = scale_factors

  @property
  def code_size(self):
    """Number of values in the code of a single box (always 4)."""
    return 4

  def _encode(self, boxes, anchors):
    """Encode a box collection as center coordinates and sizes.

    Args:
      boxes: BoxList holding N boxes to be encoded.
      anchors: BoxList of anchors. Unused by this coder.

    Returns:
      a tensor representing N encoded boxes of the format
      [ty, tx, th, tw] = [ycenter, xcenter, height + EPSILON, width + EPSILON].
    """
    # Convert boxes to the center coordinate representation.
    ycenter, xcenter, height, width = boxes.get_center_coordinates_and_sizes()
    # NOTE(review): the EPSILON offset looks like a leftover from a log-based
    # encoding; no division or log is taken here. Kept to preserve the
    # existing numeric output -- confirm whether it is still needed.
    th = height + EPSILON
    tw = width + EPSILON
    # Stack to shape [4, N], then transpose so each row is one box code.
    return tf.transpose(tf.stack([ycenter, xcenter, th, tw]))

  def _decode(self, rel_codes, anchors):
    """Decode center-size codes to corner boxes.

    Args:
      rel_codes: a tensor representing N encoded boxes of the format
        [ty, tx, th, tw].
      anchors: BoxList of anchors. Unused by this coder.

    Returns:
      boxes: BoxList holding N bounding boxes as [ymin, xmin, ymax, xmax].
    """
    # Transpose so the four code components can be unstacked individually.
    ycenter, xcenter, height, width = tf.unstack(tf.transpose(rel_codes))
    half_height = height / 2.
    half_width = width / 2.
    ymin = ycenter - half_height
    xmin = xcenter - half_width
    ymax = ycenter + half_height
    xmax = xcenter + half_width
    return box_list.BoxList(tf.transpose(tf.stack([ymin, xmin, ymax, xmax])))
...@@ -446,7 +446,7 @@ def create_target_assigner(reference, stage=None, ...@@ -446,7 +446,7 @@ def create_target_assigner(reference, stage=None,
elif reference == 'DETR': elif reference == 'DETR':
similarity_calc = sim_calc.DETRSimilarity() similarity_calc = sim_calc.DETRSimilarity()
matcher = hungarian_matcher.HungarianBipartiteMatcher() matcher = hungarian_matcher.HungarianBipartiteMatcher()
box_coder_instance = None box_coder_instance = detr_box_coder.DETRBoxCoder()
else: else:
raise ValueError('No valid combination of reference and stage.') raise ValueError('No valid combination of reference and stage.')
...@@ -481,6 +481,7 @@ def batch_assign(target_assigner, ...@@ -481,6 +481,7 @@ def batch_assign(target_assigner,
function (which have shape [num_gt_boxes, d_1, d_2, ..., d_k]). function (which have shape [num_gt_boxes, d_1, d_2, ..., d_k]).
gt_weights_batch: A list of 1-D tf.float32 tensors of shape gt_weights_batch: A list of 1-D tf.float32 tensors of shape
[num_boxes] containing weights for groundtruth boxes. [num_boxes] containing weights for groundtruth boxes.
class_predictions: A
Returns: Returns:
batch_cls_targets: a tensor with shape [batch_size, num_anchors, batch_cls_targets: a tensor with shape [batch_size, num_anchors,
...@@ -521,7 +522,10 @@ def batch_assign(target_assigner, ...@@ -521,7 +522,10 @@ def batch_assign(target_assigner,
match_list = [] match_list = []
if gt_weights_batch is None: if gt_weights_batch is None:
gt_weights_batch = [None] * len(gt_class_targets_batch) gt_weights_batch = [None] * len(gt_class_targets_batch)
class_predictions = tf.unstack(class_predictions) if class_predictions:
class_predictions = tf.unstack(class_predictions)
else:
class_predictions = [None] * len(gt_class_targets_batch)
for anchors, gt_boxes, gt_class_targets, gt_weights, class_preds in zip( for anchors, gt_boxes, gt_class_targets, gt_weights, class_preds in zip(
anchors_batch, gt_box_batch, gt_class_targets_batch, gt_weights_batch, anchors_batch, gt_box_batch, gt_class_targets_batch, gt_weights_batch,
class_predictions): class_predictions):
......
...@@ -19,6 +19,7 @@ import tensorflow.compat.v1 as tf ...@@ -19,6 +19,7 @@ import tensorflow.compat.v1 as tf
from object_detection.box_coders import keypoint_box_coder from object_detection.box_coders import keypoint_box_coder
from object_detection.box_coders import mean_stddev_box_coder from object_detection.box_coders import mean_stddev_box_coder
from object_detection.box_coders import detr_box_coder
from object_detection.core import box_list from object_detection.core import box_list
from object_detection.core import region_similarity_calculator from object_detection.core import region_similarity_calculator
from object_detection.core import standard_fields as fields from object_detection.core import standard_fields as fields
...@@ -1927,7 +1928,7 @@ class CenterNetMaskTargetAssignerTest(test_case.TestCase): ...@@ -1927,7 +1928,7 @@ class CenterNetMaskTargetAssignerTest(test_case.TestCase):
groundtruth_labels, predicted_labels): groundtruth_labels, predicted_labels):
similarity_calc = region_similarity_calculator.DETRSimilarity() similarity_calc = region_similarity_calculator.DETRSimilarity()
matcher = hungarian_matcher.HungarianBipartiteMatcher() matcher = hungarian_matcher.HungarianBipartiteMatcher()
box_coder = box box_coder = detr_box_coder.DETRBoxCoder()
target_assigner = targetassigner.TargetAssigner( target_assigner = targetassigner.TargetAssigner(
similarity_calc, matcher, box_coder) similarity_calc, matcher, box_coder)
anchors_boxlist = box_list.BoxList(anchor_means) anchors_boxlist = box_list.BoxList(anchor_means)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment