Commit 1ed7ef39 authored by Kaushik Shivakumar's avatar Kaushik Shivakumar
Browse files

fix naming

parent e350c59c
......@@ -1982,86 +1982,86 @@ class DETRTargetAssigner(object):
batch_reg_weights)
def assign(self,
box_preds,
groundtruth_boxes,
pred_class_batch,
groundtruth_labels,
groundtruth_weights=None):
pred_boxes,
gt_boxes,
pred_classes,
gt_labels,
gt_weights=None):
"""Assign classification and regression targets to each box_pred.
For a given set of box_preds and groundtruth detections, match box_preds
to groundtruth_boxes and assign classification and regression targets to
For a given set of pred_boxes and groundtruth detections, match pred_boxes
to gt_boxes and assign classification and regression targets to
each box_pred as well as weights based on the resulting match (specifying,
e.g., which box_preds should not contribute to training loss).
e.g., which pred_boxes should not contribute to training loss).
box_preds that are not matched to anything are given a classification target
pred_boxes that are not matched to anything are given a classification target
of self._unmatched_cls_target which can be specified via the constructor.
Args:
box_preds: a BoxList representing N box_preds
groundtruth_boxes: a BoxList representing M groundtruth boxes
pred_class_batch: A tensor with shape [max_num_boxes, num_classes]
pred_boxes: a BoxList representing N pred_boxes
gt_boxes: a BoxList representing M groundtruth boxes
pred_classes: A tensor with shape [max_num_boxes, num_classes]
to be used by certain similarity calculators.
groundtruth_labels: a tensor of shape [M, num_classes]
gt_labels: a tensor of shape [M, num_classes]
with labels for each of the ground_truth boxes. The subshape
[num_classes] can be empty (corresponding to scalar inputs). When set
to None, groundtruth_labels assumes a binary problem where all
to None, gt_labels assumes a binary problem where all
ground_truth boxes get a positive label (of 1).
groundtruth_weights: a float tensor of shape [M] indicating the weight to
assign to all box_preds match to a particular groundtruth box. The
gt_weights: a float tensor of shape [M] indicating the weight to
assign to all pred_boxes match to a particular groundtruth box. The
weights must be in [0., 1.]. If None, all weights are set to 1.
Generally no groundtruth boxes with zero weight match to any box_preds
Generally no groundtruth boxes with zero weight match to any pred_boxes
as matchers are aware of groundtruth weights. Additionally,
`cls_weights` and `reg_weights` are calculated using groundtruth
weights as an added safety.
Returns:
cls_targets: a float32 tensor with shape [num_box_preds, num_classes],
where the subshape [num_classes] is compatible with groundtruth_labels
cls_targets: a float32 tensor with shape [num_pred_boxes, num_classes],
where the subshape [num_classes] is compatible with gt_labels
which has shape [num_gt_boxes, num_classes].
cls_weights: a float32 tensor with shape [num_box_preds, num_classes],
cls_weights: a float32 tensor with shape [num_pred_boxes, num_classes],
representing weights for each element in cls_targets.
reg_targets: a float32 tensor with shape [num_box_preds,
reg_targets: a float32 tensor with shape [num_pred_boxes,
box_code_dimension]
reg_weights: a float32 tensor with shape [num_box_preds]
reg_weights: a float32 tensor with shape [num_pred_boxes]
"""
unmatched_class_label = tf.constant(
[1] + [0] * (groundtruth_labels.shape[1] - 1), tf.float32)
[1] + [0] * (gt_labels.shape[1] - 1), tf.float32)
if groundtruth_weights is None:
num_gt_boxes = groundtruth_boxes.num_boxes_static()
if gt_weights is None:
num_gt_boxes = gt_boxes.num_boxes_static()
if not num_gt_boxes:
num_gt_boxes = groundtruth_boxes.num_boxes()
groundtruth_weights = tf.ones([num_gt_boxes], dtype=tf.float32)
num_gt_boxes = gt_boxes.num_boxes()
gt_weights = tf.ones([num_gt_boxes], dtype=tf.float32)
groundtruth_boxes.add_field(fields.BoxListFields.classes,
groundtruth_labels)
box_preds.add_field(fields.BoxListFields.classes, pred_class_batch)
gt_boxes.add_field(fields.BoxListFields.classes,
gt_labels)
pred_boxes.add_field(fields.BoxListFields.classes, pred_classes)
match_quality_matrix = self._similarity_calc.compare(
groundtruth_boxes,
box_preds)
gt_boxes,
pred_boxes)
match = self._matcher.match(match_quality_matrix,
valid_rows=tf.greater(groundtruth_weights, 0))
valid_rows=tf.greater(gt_weights, 0))
matched_gt_boxes = match.gather_based_on_match(
groundtruth_boxes.get(),
gt_boxes.get(),
unmatched_value=tf.zeros(4),
ignored_value=tf.zeros(4))
matched_gt_boxlist = box_list.BoxList(matched_gt_boxes)
ty, tx, th, tw = matched_gt_boxlist.get_center_coordinates_and_sizes()
reg_targets = tf.transpose(tf.stack([ty, tx, th, tw]))
cls_targets = match.gather_based_on_match(
groundtruth_labels,
gt_labels,
unmatched_value=unmatched_class_label,
ignored_value=unmatched_class_label)
reg_weights = match.gather_based_on_match(
groundtruth_weights,
gt_weights,
ignored_value=0.,
unmatched_value=0.)
cls_weights = match.gather_based_on_match(
groundtruth_weights,
gt_weights,
ignored_value=0.,
unmatched_value=1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment