Commit 213a9649 authored by Ronny Votel's avatar Ronny Votel Committed by TF Object Detection Team
Browse files

Introducing groundtruth instance mask weights.

PiperOrigin-RevId: 377096964
parent 0b9a2a74
......@@ -1414,6 +1414,7 @@ def _strict_random_crop_image(image,
label_confidences=None,
multiclass_scores=None,
masks=None,
mask_weights=None,
keypoints=None,
keypoint_visibilities=None,
densepose_num_points=None,
......@@ -1451,6 +1452,8 @@ def _strict_random_crop_image(image,
masks: (optional) rank 3 float32 tensor with shape
[num_instances, height, width] containing instance masks. The masks
are of the same height, width as the input `image`.
mask_weights: (optional) rank 1 float32 tensor with shape [num_instances]
with instance masks weights.
keypoints: (optional) rank 3 float32 tensor with shape
[num_instances, num_keypoints, 2]. The keypoints are in y-x
normalized coordinates.
......@@ -1488,7 +1491,7 @@ def _strict_random_crop_image(image,
Boxes are in normalized form.
labels: new labels.
If label_weights, multiclass_scores, masks, keypoints,
If label_weights, multiclass_scores, masks, mask_weights, keypoints,
keypoint_visibilities, densepose_num_points, densepose_part_ids, or
densepose_surface_coords is not None, the function also returns:
label_weights: rank 1 float32 tensor with shape [num_instances].
......@@ -1496,6 +1499,8 @@ def _strict_random_crop_image(image,
[num_instances, num_classes]
masks: rank 3 float32 tensor with shape [num_instances, height, width]
containing instance masks.
mask_weights: rank 1 float32 tensor with shape [num_instances] with mask
weights.
keypoints: rank 3 float32 tensor with shape
[num_instances, num_keypoints, 2]
keypoint_visibilities: rank 2 bool tensor with shape
......@@ -1605,6 +1610,12 @@ def _strict_random_crop_image(image,
0]:im_box_end[0], im_box_begin[1]:im_box_end[1]]
result.append(new_masks)
if mask_weights is not None:
mask_weights_inside_window = tf.gather(mask_weights, inside_window_ids)
mask_weights_completely_inside_window = tf.gather(
mask_weights_inside_window, keep_ids)
result.append(mask_weights_completely_inside_window)
if keypoints is not None:
keypoints_of_boxes_inside_window = tf.gather(keypoints, inside_window_ids)
keypoints_of_boxes_completely_inside_window = tf.gather(
......@@ -1654,6 +1665,7 @@ def random_crop_image(image,
label_confidences=None,
multiclass_scores=None,
masks=None,
mask_weights=None,
keypoints=None,
keypoint_visibilities=None,
densepose_num_points=None,
......@@ -1701,6 +1713,8 @@ def random_crop_image(image,
masks: (optional) rank 3 float32 tensor with shape
[num_instances, height, width] containing instance masks. The masks
are of the same height, width as the input `image`.
mask_weights: (optional) rank 1 float32 tensor with shape [num_instances]
containing weights for each instance mask.
keypoints: (optional) rank 3 float32 tensor with shape
[num_instances, num_keypoints, 2]. The keypoints are in y-x
normalized coordinates.
......@@ -1751,6 +1765,7 @@ def random_crop_image(image,
[num_instances, num_classes]
masks: rank 3 float32 tensor with shape [num_instances, height, width]
containing instance masks.
mask_weights: rank 1 float32 tensor with shape [num_instances].
keypoints: rank 3 float32 tensor with shape
[num_instances, num_keypoints, 2]
keypoint_visibilities: rank 2 bool tensor with shape
......@@ -1771,6 +1786,7 @@ def random_crop_image(image,
label_confidences=label_confidences,
multiclass_scores=multiclass_scores,
masks=masks,
mask_weights=mask_weights,
keypoints=keypoints,
keypoint_visibilities=keypoint_visibilities,
densepose_num_points=densepose_num_points,
......@@ -1803,6 +1819,8 @@ def random_crop_image(image,
outputs.append(multiclass_scores)
if masks is not None:
outputs.append(masks)
if mask_weights is not None:
outputs.append(mask_weights)
if keypoints is not None:
outputs.append(keypoints)
if keypoint_visibilities is not None:
......@@ -4388,6 +4406,7 @@ def get_default_func_arg_map(include_label_weights=True,
include_label_confidences=False,
include_multiclass_scores=False,
include_instance_masks=False,
include_instance_mask_weights=False,
include_keypoints=False,
include_keypoint_visibilities=False,
include_dense_pose=False,
......@@ -4403,6 +4422,8 @@ def get_default_func_arg_map(include_label_weights=True,
multiclass scores, too.
include_instance_masks: If True, preprocessing functions will modify the
instance masks, too.
include_instance_mask_weights: If True, preprocessing functions will modify
the instance mask weights.
include_keypoints: If True, preprocessing functions will modify the
keypoints, too.
include_keypoint_visibilities: If True, preprocessing functions will modify
......@@ -4434,6 +4455,11 @@ def get_default_func_arg_map(include_label_weights=True,
groundtruth_instance_masks = (
fields.InputDataFields.groundtruth_instance_masks)
groundtruth_instance_mask_weights = None
if include_instance_mask_weights:
groundtruth_instance_mask_weights = (
fields.InputDataFields.groundtruth_instance_mask_weights)
groundtruth_keypoints = None
if include_keypoints:
groundtruth_keypoints = fields.InputDataFields.groundtruth_keypoints
......@@ -4503,7 +4529,8 @@ def get_default_func_arg_map(include_label_weights=True,
fields.InputDataFields.groundtruth_boxes,
fields.InputDataFields.groundtruth_classes,
groundtruth_label_weights, groundtruth_label_confidences,
multiclass_scores, groundtruth_instance_masks, groundtruth_keypoints,
multiclass_scores, groundtruth_instance_masks,
groundtruth_instance_mask_weights, groundtruth_keypoints,
groundtruth_keypoint_visibilities, groundtruth_dp_num_points,
groundtruth_dp_part_ids, groundtruth_dp_surface_coords),
random_pad_image:
......
......@@ -1894,6 +1894,37 @@ class PreprocessorTest(test_case.TestCase, parameterized.TestCase):
self.assertAllClose(
new_boxes.flatten(), expected_boxes.flatten())
def testStrictRandomCropImageWithMaskWeights(self):
def graph_fn():
image = self.createColorfulTestImage()[0]
boxes = self.createTestBoxes()
labels = self.createTestLabels()
weights = self.createTestGroundtruthWeights()
masks = tf.random_uniform([2, 200, 400], dtype=tf.float32)
mask_weights = tf.constant([1.0, 0.0], dtype=tf.float32)
with mock.patch.object(
tf.image,
'sample_distorted_bounding_box'
) as mock_sample_distorted_bounding_box:
mock_sample_distorted_bounding_box.return_value = (
tf.constant([6, 143, 0], dtype=tf.int32),
tf.constant([190, 237, -1], dtype=tf.int32),
tf.constant([[[0.03, 0.3575, 0.98, 0.95]]], dtype=tf.float32))
results = preprocessor._strict_random_crop_image(
image, boxes, labels, weights, masks=masks,
mask_weights=mask_weights)
return results
(new_image, new_boxes, _, _,
new_masks, new_mask_weights) = self.execute_cpu(graph_fn, [])
expected_boxes = np.array(
[[0.0, 0.0, 0.75789469, 1.0],
[0.23157893, 0.24050637, 0.75789469, 1.0]], dtype=np.float32)
self.assertAllEqual(new_image.shape, [190, 237, 3])
self.assertAllEqual(new_masks.shape, [2, 190, 237])
self.assertAllClose(new_mask_weights, [1.0, 0.0])
self.assertAllClose(
new_boxes.flatten(), expected_boxes.flatten())
def testStrictRandomCropImageWithKeypoints(self):
def graph_fn():
image = self.createColorfulTestImage()[0]
......@@ -1947,6 +1978,7 @@ class PreprocessorTest(test_case.TestCase, parameterized.TestCase):
labels = self.createTestLabels()
weights = self.createTestGroundtruthWeights()
masks = tf.random_uniform([2, 200, 400], dtype=tf.float32)
mask_weights = tf.constant([1.0, 0.0], dtype=tf.float32)
tensor_dict = {
fields.InputDataFields.image: image,
......@@ -1954,10 +1986,12 @@ class PreprocessorTest(test_case.TestCase, parameterized.TestCase):
fields.InputDataFields.groundtruth_classes: labels,
fields.InputDataFields.groundtruth_weights: weights,
fields.InputDataFields.groundtruth_instance_masks: masks,
fields.InputDataFields.groundtruth_instance_mask_weights:
mask_weights
}
preprocessor_arg_map = preprocessor.get_default_func_arg_map(
include_instance_masks=True)
include_instance_masks=True, include_instance_mask_weights=True)
preprocessing_options = [(preprocessor.random_crop_image, {})]
......@@ -1980,16 +2014,19 @@ class PreprocessorTest(test_case.TestCase, parameterized.TestCase):
fields.InputDataFields.groundtruth_classes]
distorted_masks = distorted_tensor_dict[
fields.InputDataFields.groundtruth_instance_masks]
distorted_mask_weights = distorted_tensor_dict[
fields.InputDataFields.groundtruth_instance_mask_weights]
return [distorted_image, distorted_boxes, distorted_labels,
distorted_masks]
distorted_masks, distorted_mask_weights]
(distorted_image_, distorted_boxes_, distorted_labels_,
distorted_masks_) = self.execute_cpu(graph_fn, [])
distorted_masks_, distorted_mask_weights_) = self.execute_cpu(graph_fn, [])
expected_boxes = np.array([
[0.0, 0.0, 0.75789469, 1.0],
[0.23157893, 0.24050637, 0.75789469, 1.0],
], dtype=np.float32)
self.assertAllEqual(distorted_image_.shape, [1, 190, 237, 3])
self.assertAllEqual(distorted_masks_.shape, [2, 190, 237])
self.assertAllClose(distorted_mask_weights_, [1.0, 0.0])
self.assertAllEqual(distorted_labels_, [1, 2])
self.assertAllClose(
distorted_boxes_.flatten(), expected_boxes.flatten())
......
......@@ -64,6 +64,7 @@ class InputDataFields(object):
proposal_boxes: coordinates of object proposal boxes.
proposal_objectness: objectness score of each proposal.
groundtruth_instance_masks: ground truth instance masks.
groundtruth_instance_mask_weights: ground truth instance masks weights.
groundtruth_instance_boundaries: ground truth instance boundaries.
groundtruth_instance_classes: instance mask-level class labels.
groundtruth_keypoints: ground truth keypoints.
......@@ -122,6 +123,7 @@ class InputDataFields(object):
proposal_boxes = 'proposal_boxes'
proposal_objectness = 'proposal_objectness'
groundtruth_instance_masks = 'groundtruth_instance_masks'
groundtruth_instance_mask_weights = 'groundtruth_instance_mask_weights'
groundtruth_instance_boundaries = 'groundtruth_instance_boundaries'
groundtruth_instance_classes = 'groundtruth_instance_classes'
groundtruth_keypoints = 'groundtruth_keypoints'
......
......@@ -373,6 +373,11 @@ class TfExampleDecoder(data_decoder.DataDecoder):
self._decode_png_instance_masks))
else:
raise ValueError('Did not recognize the `instance_mask_type` option.')
self.keys_to_features['image/object/mask/weight'] = (
tf.VarLenFeature(tf.float32))
self.items_to_handlers[
fields.InputDataFields.groundtruth_instance_mask_weights] = (
slim_example_decoder.Tensor('image/object/mask/weight'))
if load_dense_pose:
self.keys_to_features['image/object/densepose/num'] = (
tf.VarLenFeature(tf.int64))
......@@ -491,6 +496,10 @@ class TfExampleDecoder(data_decoder.DataDecoder):
tensor of shape [None, num_keypoints] containing keypoint visibilites.
fields.InputDataFields.groundtruth_instance_masks - 3D float32 tensor of
shape [None, None, None] containing instance masks.
fields.InputDataFields.groundtruth_instance_mask_weights - 1D float32
tensor of shape [None] containing weights. These are typically values
in {0.0, 1.0} which indicate whether to consider the mask related to an
object.
fields.InputDataFields.groundtruth_image_classes - 1D int64 of shape
[None] containing classes for the boxes.
fields.InputDataFields.multiclass_scores - 1D float32 tensor of shape
......@@ -531,6 +540,21 @@ class TfExampleDecoder(data_decoder.DataDecoder):
0), lambda: tensor_dict[fields.InputDataFields.groundtruth_weights],
default_groundtruth_weights)
if fields.InputDataFields.groundtruth_instance_masks in tensor_dict:
gt_instance_masks = tensor_dict[
fields.InputDataFields.groundtruth_instance_masks]
num_gt_instance_masks = tf.shape(gt_instance_masks)[0]
gt_instance_mask_weights = tensor_dict[
fields.InputDataFields.groundtruth_instance_mask_weights]
num_gt_instance_mask_weights = tf.shape(gt_instance_mask_weights)[0]
def default_groundtruth_instance_mask_weights():
return tf.ones([num_gt_instance_masks], dtype=tf.float32)
tensor_dict[fields.InputDataFields.groundtruth_instance_mask_weights] = (
tf.cond(tf.greater(num_gt_instance_mask_weights, 0),
lambda: gt_instance_mask_weights,
default_groundtruth_instance_mask_weights))
if fields.InputDataFields.groundtruth_keypoints in tensor_dict:
# Set all keypoints that are not labeled to NaN.
gt_kpt_fld = fields.InputDataFields.groundtruth_keypoints
......
......@@ -1225,6 +1225,9 @@ class TfExampleDecoderTest(test_case.TestCase):
self.assertAllEqual(
instance_masks.astype(np.float32),
tensor_dict[fields.InputDataFields.groundtruth_instance_masks])
self.assertAllEqual(
tensor_dict[fields.InputDataFields.groundtruth_instance_mask_weights],
[1, 1, 1, 1])
self.assertAllEqual(object_classes,
tensor_dict[fields.InputDataFields.groundtruth_classes])
......@@ -1272,6 +1275,71 @@ class TfExampleDecoderTest(test_case.TestCase):
self.assertNotIn(fields.InputDataFields.groundtruth_instance_masks,
tensor_dict)
def testDecodeInstanceSegmentationWithWeights(self):
num_instances = 4
image_height = 5
image_width = 3
# Randomly generate image.
image_tensor = np.random.randint(
256, size=(image_height, image_width, 3)).astype(np.uint8)
encoded_jpeg, _ = self._create_encoded_and_decoded_data(
image_tensor, 'jpeg')
# Randomly generate instance segmentation masks.
instance_masks = (
np.random.randint(2, size=(num_instances, image_height,
image_width)).astype(np.float32))
instance_masks_flattened = np.reshape(instance_masks, [-1])
instance_mask_weights = np.array([1, 1, 0, 1], dtype=np.float32)
# Randomly generate class labels for each instance.
object_classes = np.random.randint(
100, size=(num_instances)).astype(np.int64)
def graph_fn():
example = tf.train.Example(
features=tf.train.Features(
feature={
'image/encoded':
dataset_util.bytes_feature(encoded_jpeg),
'image/format':
dataset_util.bytes_feature(six.b('jpeg')),
'image/height':
dataset_util.int64_feature(image_height),
'image/width':
dataset_util.int64_feature(image_width),
'image/object/mask':
dataset_util.float_list_feature(instance_masks_flattened),
'image/object/mask/weight':
dataset_util.float_list_feature(instance_mask_weights),
'image/object/class/label':
dataset_util.int64_list_feature(object_classes)
})).SerializeToString()
example_decoder = tf_example_decoder.TfExampleDecoder(
load_instance_masks=True)
output = example_decoder.decode(tf.convert_to_tensor(example))
self.assertAllEqual(
(output[fields.InputDataFields.groundtruth_instance_masks].get_shape(
).as_list()), [4, 5, 3])
self.assertAllEqual(
output[fields.InputDataFields.groundtruth_instance_mask_weights],
[1, 1, 0, 1])
self.assertAllEqual((output[
fields.InputDataFields.groundtruth_classes].get_shape().as_list()),
[4])
return output
tensor_dict = self.execute_cpu(graph_fn, [])
self.assertAllEqual(
instance_masks.astype(np.float32),
tensor_dict[fields.InputDataFields.groundtruth_instance_masks])
self.assertAllEqual(object_classes,
tensor_dict[fields.InputDataFields.groundtruth_classes])
def testDecodeImageLabels(self):
image_tensor = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
encoded_jpeg, _ = self._create_encoded_and_decoded_data(
......
......@@ -479,6 +479,7 @@ def pad_input_data_to_static_shapes(tensor_dict,
input_fields.groundtruth_instance_masks: [
max_num_boxes, height, width
],
input_fields.groundtruth_instance_mask_weights: [max_num_boxes],
input_fields.groundtruth_is_crowd: [max_num_boxes],
input_fields.groundtruth_group_of: [max_num_boxes],
input_fields.groundtruth_area: [max_num_boxes],
......@@ -601,6 +602,8 @@ def augment_input_data(tensor_dict, data_augmentation_options):
include_instance_masks = (fields.InputDataFields.groundtruth_instance_masks
in tensor_dict)
include_instance_mask_weights = (
fields.InputDataFields.groundtruth_instance_mask_weights in tensor_dict)
include_keypoints = (fields.InputDataFields.groundtruth_keypoints
in tensor_dict)
include_keypoint_visibilities = (
......@@ -624,6 +627,7 @@ def augment_input_data(tensor_dict, data_augmentation_options):
include_label_confidences=include_label_confidences,
include_multiclass_scores=include_multiclass_scores,
include_instance_masks=include_instance_masks,
include_instance_mask_weights=include_instance_mask_weights,
include_keypoints=include_keypoints,
include_keypoint_visibilities=include_keypoint_visibilities,
include_dense_pose=include_dense_pose,
......@@ -652,6 +656,7 @@ def _get_labels_dict(input_dict):
fields.InputDataFields.groundtruth_keypoint_depths,
fields.InputDataFields.groundtruth_keypoint_depth_weights,
fields.InputDataFields.groundtruth_instance_masks,
fields.InputDataFields.groundtruth_instance_mask_weights,
fields.InputDataFields.groundtruth_area,
fields.InputDataFields.groundtruth_is_crowd,
fields.InputDataFields.groundtruth_group_of,
......@@ -804,6 +809,9 @@ def train_input(train_config, train_input_config,
labels[fields.InputDataFields.groundtruth_instance_masks] is a
[batch_size, num_boxes, H, W] float32 tensor containing only binary
values, which represent instance masks for objects.
labels[fields.InputDataFields.groundtruth_instance_mask_weights] is a
[batch_size, num_boxes] float32 tensor containing groundtruth weights
for each instance mask.
labels[fields.InputDataFields.groundtruth_keypoints] is a
[batch_size, num_boxes, num_keypoints, 2] float32 tensor containing
keypoints for each box.
......@@ -961,6 +969,9 @@ def eval_input(eval_config, eval_input_config, model_config,
labels[fields.InputDataFields.groundtruth_instance_masks] is a
[1, num_boxes, H, W] float32 tensor containing only binary values,
which represent instance masks for objects.
labels[fields.InputDataFields.groundtruth_instance_mask_weights] is a
[1, num_boxes] float32 tensor containing groundtruth weights for each
instance mask.
labels[fields.InputDataFields.groundtruth_weights] is a
[batch_size, num_boxes, num_keypoints] float32 tensor containing
groundtruth weights for the keypoints.
......
......@@ -795,15 +795,20 @@ class DataAugmentationFnTest(test_case.TestCase):
fields.InputDataFields.image:
tf.constant(np.random.rand(10, 10, 3).astype(np.float32)),
fields.InputDataFields.groundtruth_instance_masks:
tf.constant(np.zeros([2, 10, 10], np.uint8))
tf.constant(np.zeros([2, 10, 10], np.uint8)),
fields.InputDataFields.groundtruth_instance_mask_weights:
tf.constant([1.0, 0.0], np.float32)
}
augmented_tensor_dict = data_augmentation_fn(tensor_dict=tensor_dict)
return (augmented_tensor_dict[fields.InputDataFields.image],
augmented_tensor_dict[fields.InputDataFields.
groundtruth_instance_masks])
image, masks = self.execute_cpu(graph_fn, [])
groundtruth_instance_masks],
augmented_tensor_dict[fields.InputDataFields.
groundtruth_instance_mask_weights])
image, masks, mask_weights = self.execute_cpu(graph_fn, [])
self.assertAllEqual(image.shape, [20, 20, 3])
self.assertAllEqual(masks.shape, [2, 20, 20])
self.assertAllClose(mask_weights, [1.0, 0.0])
def test_include_keypoints_in_data_augmentation(self):
data_augmentation_options = [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment