Commit 6bccc202, authored by Ronny Votel, committed by TF Object Detection Team
Browse files

Plumbing groundtruth instance mask weights through the model codebase.

PiperOrigin-RevId: 377104676
parent 677aaa11
...@@ -101,7 +101,7 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)): ...@@ -101,7 +101,7 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)):
Args: Args:
field: a string key, options are field: a string key, options are
fields.BoxListFields.{boxes,classes,masks,keypoints, fields.BoxListFields.{boxes,classes,masks,mask_weights,keypoints,
keypoint_visibilities, densepose_*, track_ids, keypoint_visibilities, densepose_*, track_ids,
temporal_offsets, track_match_flags} temporal_offsets, track_match_flags}
fields.InputDataFields.is_annotated. fields.InputDataFields.is_annotated.
...@@ -123,7 +123,7 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)): ...@@ -123,7 +123,7 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)):
Args: Args:
field: a string key, options are field: a string key, options are
fields.BoxListFields.{boxes,classes,masks,keypoints, fields.BoxListFields.{boxes,classes,masks,mask_weights,keypoints,
keypoint_visibilities, densepose_*, track_ids} or keypoint_visibilities, densepose_*, track_ids} or
fields.InputDataFields.is_annotated. fields.InputDataFields.is_annotated.
...@@ -299,6 +299,7 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)): ...@@ -299,6 +299,7 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)):
groundtruth_boxes_list, groundtruth_boxes_list,
groundtruth_classes_list, groundtruth_classes_list,
groundtruth_masks_list=None, groundtruth_masks_list=None,
groundtruth_mask_weights_list=None,
groundtruth_keypoints_list=None, groundtruth_keypoints_list=None,
groundtruth_keypoint_visibilities_list=None, groundtruth_keypoint_visibilities_list=None,
groundtruth_dp_num_points_list=None, groundtruth_dp_num_points_list=None,
...@@ -334,6 +335,8 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)): ...@@ -334,6 +335,8 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)):
masks with values in {0, 1}. If None, no masks are provided. masks with values in {0, 1}. If None, no masks are provided.
Mask resolution `height_in`x`width_in` must agree with the resolution Mask resolution `height_in`x`width_in` must agree with the resolution
of the input image tensor provided to the `preprocess` function. of the input image tensor provided to the `preprocess` function.
groundtruth_mask_weights_list: a list of 1-D tf.float32 tensors of shape
[num_boxes] with weights for each instance mask.
groundtruth_keypoints_list: a list of 3-D tf.float32 tensors of groundtruth_keypoints_list: a list of 3-D tf.float32 tensors of
shape [num_boxes, num_keypoints, 2] containing keypoints. shape [num_boxes, num_keypoints, 2] containing keypoints.
Keypoints are assumed to be provided in normalized coordinates and Keypoints are assumed to be provided in normalized coordinates and
...@@ -399,6 +402,9 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)): ...@@ -399,6 +402,9 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)):
if groundtruth_masks_list: if groundtruth_masks_list:
self._groundtruth_lists[ self._groundtruth_lists[
fields.BoxListFields.masks] = groundtruth_masks_list fields.BoxListFields.masks] = groundtruth_masks_list
if groundtruth_mask_weights_list:
self._groundtruth_lists[
fields.BoxListFields.mask_weights] = groundtruth_mask_weights_list
if groundtruth_keypoints_list: if groundtruth_keypoints_list:
self._groundtruth_lists[ self._groundtruth_lists[
fields.BoxListFields.keypoints] = groundtruth_keypoints_list fields.BoxListFields.keypoints] = groundtruth_keypoints_list
......
...@@ -210,6 +210,7 @@ class BoxListFields(object): ...@@ -210,6 +210,7 @@ class BoxListFields(object):
weights: sample weights per bounding box. weights: sample weights per bounding box.
objectness: objectness score per bounding box. objectness: objectness score per bounding box.
masks: masks per bounding box. masks: masks per bounding box.
mask_weights: mask weights for each bounding box.
boundaries: boundaries per bounding box. boundaries: boundaries per bounding box.
keypoints: keypoints per bounding box. keypoints: keypoints per bounding box.
keypoint_visibilities: keypoint visibilities per bounding box. keypoint_visibilities: keypoint visibilities per bounding box.
...@@ -230,6 +231,7 @@ class BoxListFields(object): ...@@ -230,6 +231,7 @@ class BoxListFields(object):
confidences = 'confidences' confidences = 'confidences'
objectness = 'objectness' objectness = 'objectness'
masks = 'masks' masks = 'masks'
mask_weights = 'mask_weights'
boundaries = 'boundaries' boundaries = 'boundaries'
keypoints = 'keypoints' keypoints = 'keypoints'
keypoint_visibilities = 'keypoint_visibilities' keypoint_visibilities = 'keypoint_visibilities'
......
...@@ -266,6 +266,7 @@ def unstack_batch(tensor_dict, unpad_groundtruth_tensors=True): ...@@ -266,6 +266,7 @@ def unstack_batch(tensor_dict, unpad_groundtruth_tensors=True):
# dimension. This list has to be kept in sync with InputDataFields in # dimension. This list has to be kept in sync with InputDataFields in
# standard_fields.py. # standard_fields.py.
fields.InputDataFields.groundtruth_instance_masks, fields.InputDataFields.groundtruth_instance_masks,
fields.InputDataFields.groundtruth_instance_mask_weights,
fields.InputDataFields.groundtruth_classes, fields.InputDataFields.groundtruth_classes,
fields.InputDataFields.groundtruth_boxes, fields.InputDataFields.groundtruth_boxes,
fields.InputDataFields.groundtruth_keypoints, fields.InputDataFields.groundtruth_keypoints,
...@@ -319,6 +320,10 @@ def provide_groundtruth(model, labels): ...@@ -319,6 +320,10 @@ def provide_groundtruth(model, labels):
if fields.InputDataFields.groundtruth_instance_masks in labels: if fields.InputDataFields.groundtruth_instance_masks in labels:
gt_masks_list = labels[ gt_masks_list = labels[
fields.InputDataFields.groundtruth_instance_masks] fields.InputDataFields.groundtruth_instance_masks]
gt_mask_weights_list = None
if fields.InputDataFields.groundtruth_instance_mask_weights in labels:
gt_mask_weights_list = labels[
fields.InputDataFields.groundtruth_instance_mask_weights]
gt_keypoints_list = None gt_keypoints_list = None
if fields.InputDataFields.groundtruth_keypoints in labels: if fields.InputDataFields.groundtruth_keypoints in labels:
gt_keypoints_list = labels[fields.InputDataFields.groundtruth_keypoints] gt_keypoints_list = labels[fields.InputDataFields.groundtruth_keypoints]
...@@ -383,6 +388,7 @@ def provide_groundtruth(model, labels): ...@@ -383,6 +388,7 @@ def provide_groundtruth(model, labels):
groundtruth_confidences_list=gt_confidences_list, groundtruth_confidences_list=gt_confidences_list,
groundtruth_labeled_classes=gt_labeled_classes, groundtruth_labeled_classes=gt_labeled_classes,
groundtruth_masks_list=gt_masks_list, groundtruth_masks_list=gt_masks_list,
groundtruth_mask_weights_list=gt_mask_weights_list,
groundtruth_keypoints_list=gt_keypoints_list, groundtruth_keypoints_list=gt_keypoints_list,
groundtruth_keypoint_visibilities_list=gt_keypoint_visibilities_list, groundtruth_keypoint_visibilities_list=gt_keypoint_visibilities_list,
groundtruth_dp_num_points_list=gt_dp_num_points_list, groundtruth_dp_num_points_list=gt_dp_num_points_list,
......
...@@ -87,6 +87,8 @@ def _compute_losses_and_predictions_dicts( ...@@ -87,6 +87,8 @@ def _compute_losses_and_predictions_dicts(
labels[fields.InputDataFields.groundtruth_instance_masks] is a labels[fields.InputDataFields.groundtruth_instance_masks] is a
float32 tensor containing only binary values, which represent float32 tensor containing only binary values, which represent
instance masks for objects. instance masks for objects.
labels[fields.InputDataFields.groundtruth_instance_mask_weights] is a
float32 tensor containing weights for the instance masks.
labels[fields.InputDataFields.groundtruth_keypoints] is a labels[fields.InputDataFields.groundtruth_keypoints] is a
float32 tensor containing keypoints for each box. float32 tensor containing keypoints for each box.
labels[fields.InputDataFields.groundtruth_dp_num_points] is an int32 labels[fields.InputDataFields.groundtruth_dp_num_points] is an int32
...@@ -237,6 +239,9 @@ def eager_train_step(detection_model, ...@@ -237,6 +239,9 @@ def eager_train_step(detection_model,
labels[fields.InputDataFields.groundtruth_instance_masks] is a labels[fields.InputDataFields.groundtruth_instance_masks] is a
[batch_size, num_boxes, H, W] float32 tensor containing only binary [batch_size, num_boxes, H, W] float32 tensor containing only binary
values, which represent instance masks for objects. values, which represent instance masks for objects.
labels[fields.InputDataFields.groundtruth_instance_mask_weights] is a
[batch_size, num_boxes] float32 tensor containing weights for the
instance masks.
labels[fields.InputDataFields.groundtruth_keypoints] is a labels[fields.InputDataFields.groundtruth_keypoints] is a
[batch_size, num_boxes, num_keypoints, 2] float32 tensor containing [batch_size, num_boxes, num_keypoints, 2] float32 tensor containing
keypoints for each box. keypoints for each box.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment