Commit 6bccc202 authored by Ronny Votel's avatar Ronny Votel Committed by TF Object Detection Team
Browse files

Plumbing groundtruth instance mask weights through the model codebase.

PiperOrigin-RevId: 377104676
parent 677aaa11
......@@ -101,7 +101,7 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)):
Args:
field: a string key, options are
fields.BoxListFields.{boxes,classes,masks,keypoints,
fields.BoxListFields.{boxes,classes,masks,mask_weights,keypoints,
keypoint_visibilities, densepose_*, track_ids,
temporal_offsets, track_match_flags}
fields.InputDataFields.is_annotated.
......@@ -123,7 +123,7 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)):
Args:
field: a string key, options are
fields.BoxListFields.{boxes,classes,masks,keypoints,
fields.BoxListFields.{boxes,classes,masks,mask_weights,keypoints,
keypoint_visibilities, densepose_*, track_ids} or
fields.InputDataFields.is_annotated.
......@@ -299,6 +299,7 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)):
groundtruth_boxes_list,
groundtruth_classes_list,
groundtruth_masks_list=None,
groundtruth_mask_weights_list=None,
groundtruth_keypoints_list=None,
groundtruth_keypoint_visibilities_list=None,
groundtruth_dp_num_points_list=None,
......@@ -334,6 +335,8 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)):
masks with values in {0, 1}. If None, no masks are provided.
Mask resolution `height_in`x`width_in` must agree with the resolution
of the input image tensor provided to the `preprocess` function.
groundtruth_mask_weights_list: a list of 1-D tf.float32 tensors of shape
[num_boxes] with weights for each instance mask.
groundtruth_keypoints_list: a list of 3-D tf.float32 tensors of
shape [num_boxes, num_keypoints, 2] containing keypoints.
Keypoints are assumed to be provided in normalized coordinates and
......@@ -399,6 +402,9 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)):
if groundtruth_masks_list:
self._groundtruth_lists[
fields.BoxListFields.masks] = groundtruth_masks_list
if groundtruth_mask_weights_list:
self._groundtruth_lists[
fields.BoxListFields.mask_weights] = groundtruth_mask_weights_list
if groundtruth_keypoints_list:
self._groundtruth_lists[
fields.BoxListFields.keypoints] = groundtruth_keypoints_list
......
......@@ -210,6 +210,7 @@ class BoxListFields(object):
weights: sample weights per bounding box.
objectness: objectness score per bounding box.
masks: masks per bounding box.
mask_weights: mask weights for each bounding box.
boundaries: boundaries per bounding box.
keypoints: keypoints per bounding box.
keypoint_visibilities: keypoint visibilities per bounding box.
......@@ -230,6 +231,7 @@ class BoxListFields(object):
confidences = 'confidences'
objectness = 'objectness'
masks = 'masks'
mask_weights = 'mask_weights'
boundaries = 'boundaries'
keypoints = 'keypoints'
keypoint_visibilities = 'keypoint_visibilities'
......
......@@ -266,6 +266,7 @@ def unstack_batch(tensor_dict, unpad_groundtruth_tensors=True):
# dimension. This list has to be kept in sync with InputDataFields in
# standard_fields.py.
fields.InputDataFields.groundtruth_instance_masks,
fields.InputDataFields.groundtruth_instance_mask_weights,
fields.InputDataFields.groundtruth_classes,
fields.InputDataFields.groundtruth_boxes,
fields.InputDataFields.groundtruth_keypoints,
......@@ -319,6 +320,10 @@ def provide_groundtruth(model, labels):
if fields.InputDataFields.groundtruth_instance_masks in labels:
gt_masks_list = labels[
fields.InputDataFields.groundtruth_instance_masks]
gt_mask_weights_list = None
if fields.InputDataFields.groundtruth_instance_mask_weights in labels:
gt_mask_weights_list = labels[
fields.InputDataFields.groundtruth_instance_mask_weights]
gt_keypoints_list = None
if fields.InputDataFields.groundtruth_keypoints in labels:
gt_keypoints_list = labels[fields.InputDataFields.groundtruth_keypoints]
......@@ -383,6 +388,7 @@ def provide_groundtruth(model, labels):
groundtruth_confidences_list=gt_confidences_list,
groundtruth_labeled_classes=gt_labeled_classes,
groundtruth_masks_list=gt_masks_list,
groundtruth_mask_weights_list=gt_mask_weights_list,
groundtruth_keypoints_list=gt_keypoints_list,
groundtruth_keypoint_visibilities_list=gt_keypoint_visibilities_list,
groundtruth_dp_num_points_list=gt_dp_num_points_list,
......
......@@ -87,6 +87,8 @@ def _compute_losses_and_predictions_dicts(
labels[fields.InputDataFields.groundtruth_instance_masks] is a
float32 tensor containing only binary values, which represent
instance masks for objects.
labels[fields.InputDataFields.groundtruth_instance_mask_weights] is a
float32 tensor containing weights for the instance masks.
labels[fields.InputDataFields.groundtruth_keypoints] is a
float32 tensor containing keypoints for each box.
labels[fields.InputDataFields.groundtruth_dp_num_points] is an int32
......@@ -237,6 +239,9 @@ def eager_train_step(detection_model,
labels[fields.InputDataFields.groundtruth_instance_masks] is a
[batch_size, num_boxes, H, W] float32 tensor containing only binary
values, which represent instance masks for objects.
labels[fields.InputDataFields.groundtruth_instance_mask_weights] is a
[batch_size, num_boxes] float32 tensor containing weights for the
instance masks.
labels[fields.InputDataFields.groundtruth_keypoints] is a
[batch_size, num_boxes, num_keypoints, 2] float32 tensor containing
keypoints for each box.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment