Commit 1ed7ef39 authored by Kaushik Shivakumar's avatar Kaushik Shivakumar
Browse files

fix naming

parent e350c59c
...@@ -1982,86 +1982,86 @@ class DETRTargetAssigner(object): ...@@ -1982,86 +1982,86 @@ class DETRTargetAssigner(object):
batch_reg_weights) batch_reg_weights)
def assign(self, def assign(self,
box_preds, pred_boxes,
groundtruth_boxes, gt_boxes,
pred_class_batch, pred_classes,
groundtruth_labels, gt_labels,
groundtruth_weights=None): gt_weights=None):
"""Assign classification and regression targets to each box_pred. """Assign classification and regression targets to each box_pred.
For a given set of box_preds and groundtruth detections, match box_preds For a given set of pred_boxes and groundtruth detections, match pred_boxes
to groundtruth_boxes and assign classification and regression targets to to gt_boxes and assign classification and regression targets to
each box_pred as well as weights based on the resulting match (specifying, each box_pred as well as weights based on the resulting match (specifying,
e.g., which box_preds should not contribute to training loss). e.g., which pred_boxes should not contribute to training loss).
box_preds that are not matched to anything are given a classification target pred_boxes that are not matched to anything are given a classification target
of self._unmatched_cls_target which can be specified via the constructor. of self._unmatched_cls_target which can be specified via the constructor.
Args: Args:
box_preds: a BoxList representing N box_preds pred_boxes: a BoxList representing N pred_boxes
groundtruth_boxes: a BoxList representing M groundtruth boxes gt_boxes: a BoxList representing M groundtruth boxes
pred_class_batch: A tensor with shape [max_num_boxes, num_classes] pred_classes: A tensor with shape [max_num_boxes, num_classes]
to be used by certain similarity calculators. to be used by certain similarity calculators.
groundtruth_labels: a tensor of shape [M, num_classes] gt_labels: a tensor of shape [M, num_classes]
with labels for each of the ground_truth boxes. The subshape with labels for each of the ground_truth boxes. The subshape
[num_classes] can be empty (corresponding to scalar inputs). When set [num_classes] can be empty (corresponding to scalar inputs). When set
to None, groundtruth_labels assumes a binary problem where all to None, gt_labels assumes a binary problem where all
ground_truth boxes get a positive label (of 1). ground_truth boxes get a positive label (of 1).
groundtruth_weights: a float tensor of shape [M] indicating the weight to gt_weights: a float tensor of shape [M] indicating the weight to
assign to all box_preds match to a particular groundtruth box. The assign to all pred_boxes match to a particular groundtruth box. The
weights must be in [0., 1.]. If None, all weights are set to 1. weights must be in [0., 1.]. If None, all weights are set to 1.
Generally no groundtruth boxes with zero weight match to any box_preds Generally no groundtruth boxes with zero weight match to any pred_boxes
as matchers are aware of groundtruth weights. Additionally, as matchers are aware of groundtruth weights. Additionally,
`cls_weights` and `reg_weights` are calculated using groundtruth `cls_weights` and `reg_weights` are calculated using groundtruth
weights as an added safety. weights as an added safety.
Returns: Returns:
cls_targets: a float32 tensor with shape [num_box_preds, num_classes], cls_targets: a float32 tensor with shape [num_pred_boxes, num_classes],
where the subshape [num_classes] is compatible with groundtruth_labels where the subshape [num_classes] is compatible with gt_labels
which has shape [num_gt_boxes, num_classes]. which has shape [num_gt_boxes, num_classes].
cls_weights: a float32 tensor with shape [num_box_preds, num_classes], cls_weights: a float32 tensor with shape [num_pred_boxes, num_classes],
representing weights for each element in cls_targets. representing weights for each element in cls_targets.
reg_targets: a float32 tensor with shape [num_box_preds, reg_targets: a float32 tensor with shape [num_pred_boxes,
box_code_dimension] box_code_dimension]
reg_weights: a float32 tensor with shape [num_box_preds] reg_weights: a float32 tensor with shape [num_pred_boxes]
""" """
unmatched_class_label = tf.constant( unmatched_class_label = tf.constant(
[1] + [0] * (groundtruth_labels.shape[1] - 1), tf.float32) [1] + [0] * (gt_labels.shape[1] - 1), tf.float32)
if groundtruth_weights is None: if gt_weights is None:
num_gt_boxes = groundtruth_boxes.num_boxes_static() num_gt_boxes = gt_boxes.num_boxes_static()
if not num_gt_boxes: if not num_gt_boxes:
num_gt_boxes = groundtruth_boxes.num_boxes() num_gt_boxes = gt_boxes.num_boxes()
groundtruth_weights = tf.ones([num_gt_boxes], dtype=tf.float32) gt_weights = tf.ones([num_gt_boxes], dtype=tf.float32)
groundtruth_boxes.add_field(fields.BoxListFields.classes, gt_boxes.add_field(fields.BoxListFields.classes,
groundtruth_labels) gt_labels)
box_preds.add_field(fields.BoxListFields.classes, pred_class_batch) pred_boxes.add_field(fields.BoxListFields.classes, pred_classes)
match_quality_matrix = self._similarity_calc.compare( match_quality_matrix = self._similarity_calc.compare(
groundtruth_boxes, gt_boxes,
box_preds) pred_boxes)
match = self._matcher.match(match_quality_matrix, match = self._matcher.match(match_quality_matrix,
valid_rows=tf.greater(groundtruth_weights, 0)) valid_rows=tf.greater(gt_weights, 0))
matched_gt_boxes = match.gather_based_on_match( matched_gt_boxes = match.gather_based_on_match(
groundtruth_boxes.get(), gt_boxes.get(),
unmatched_value=tf.zeros(4), unmatched_value=tf.zeros(4),
ignored_value=tf.zeros(4)) ignored_value=tf.zeros(4))
matched_gt_boxlist = box_list.BoxList(matched_gt_boxes) matched_gt_boxlist = box_list.BoxList(matched_gt_boxes)
ty, tx, th, tw = matched_gt_boxlist.get_center_coordinates_and_sizes() ty, tx, th, tw = matched_gt_boxlist.get_center_coordinates_and_sizes()
reg_targets = tf.transpose(tf.stack([ty, tx, th, tw])) reg_targets = tf.transpose(tf.stack([ty, tx, th, tw]))
cls_targets = match.gather_based_on_match( cls_targets = match.gather_based_on_match(
groundtruth_labels, gt_labels,
unmatched_value=unmatched_class_label, unmatched_value=unmatched_class_label,
ignored_value=unmatched_class_label) ignored_value=unmatched_class_label)
reg_weights = match.gather_based_on_match( reg_weights = match.gather_based_on_match(
groundtruth_weights, gt_weights,
ignored_value=0., ignored_value=0.,
unmatched_value=0.) unmatched_value=0.)
cls_weights = match.gather_based_on_match( cls_weights = match.gather_based_on_match(
groundtruth_weights, gt_weights,
ignored_value=0., ignored_value=0.,
unmatched_value=1) unmatched_value=1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment