Commit fc585957 authored by Vivek Rathod's avatar Vivek Rathod Committed by TF Object Detection Team
Browse files

Support Partial labels with Faster R-CNN.

PiperOrigin-RevId: 325536056
parent 284eacdf
......@@ -2412,7 +2412,15 @@ class FasterRCNNMetaArch(model.DetectionModel):
unmatched_class_label=tf.constant(
[1] + self._num_classes * [0], dtype=tf.float32),
gt_weights_batch=groundtruth_weights_list)
if self.groundtruth_has_field(
fields.InputDataFields.groundtruth_labeled_classes):
gt_labeled_classes = self.groundtruth_lists(
fields.InputDataFields.groundtruth_labeled_classes)
gt_labeled_classes = tf.pad(
gt_labeled_classes, [[0, 0], [1, 0]],
mode='CONSTANT',
constant_values=1)
batch_cls_weights *= tf.expand_dims(gt_labeled_classes, 1)
class_predictions_with_background = tf.reshape(
class_predictions_with_background,
[batch_size, self.max_num_proposals, -1])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment