Commit 213a9649 authored by Ronny Votel, committed by TF Object Detection Team
Browse files

Introducing groundtruth instance mask weights.

PiperOrigin-RevId: 377096964
parent 0b9a2a74
...@@ -1414,6 +1414,7 @@ def _strict_random_crop_image(image, ...@@ -1414,6 +1414,7 @@ def _strict_random_crop_image(image,
label_confidences=None, label_confidences=None,
multiclass_scores=None, multiclass_scores=None,
masks=None, masks=None,
mask_weights=None,
keypoints=None, keypoints=None,
keypoint_visibilities=None, keypoint_visibilities=None,
densepose_num_points=None, densepose_num_points=None,
...@@ -1451,6 +1452,8 @@ def _strict_random_crop_image(image, ...@@ -1451,6 +1452,8 @@ def _strict_random_crop_image(image,
masks: (optional) rank 3 float32 tensor with shape masks: (optional) rank 3 float32 tensor with shape
[num_instances, height, width] containing instance masks. The masks [num_instances, height, width] containing instance masks. The masks
are of the same height, width as the input `image`. are of the same height, width as the input `image`.
mask_weights: (optional) rank 1 float32 tensor with shape [num_instances]
containing weights for each instance mask.
keypoints: (optional) rank 3 float32 tensor with shape keypoints: (optional) rank 3 float32 tensor with shape
[num_instances, num_keypoints, 2]. The keypoints are in y-x [num_instances, num_keypoints, 2]. The keypoints are in y-x
normalized coordinates. normalized coordinates.
...@@ -1488,7 +1491,7 @@ def _strict_random_crop_image(image, ...@@ -1488,7 +1491,7 @@ def _strict_random_crop_image(image,
Boxes are in normalized form. Boxes are in normalized form.
labels: new labels. labels: new labels.
If label_weights, multiclass_scores, masks, keypoints, If label_weights, multiclass_scores, masks, mask_weights, keypoints,
keypoint_visibilities, densepose_num_points, densepose_part_ids, or keypoint_visibilities, densepose_num_points, densepose_part_ids, or
densepose_surface_coords is not None, the function also returns: densepose_surface_coords is not None, the function also returns:
label_weights: rank 1 float32 tensor with shape [num_instances]. label_weights: rank 1 float32 tensor with shape [num_instances].
...@@ -1496,6 +1499,8 @@ def _strict_random_crop_image(image, ...@@ -1496,6 +1499,8 @@ def _strict_random_crop_image(image,
[num_instances, num_classes] [num_instances, num_classes]
masks: rank 3 float32 tensor with shape [num_instances, height, width] masks: rank 3 float32 tensor with shape [num_instances, height, width]
containing instance masks. containing instance masks.
mask_weights: rank 1 float32 tensor with shape [num_instances] with mask
weights.
keypoints: rank 3 float32 tensor with shape keypoints: rank 3 float32 tensor with shape
[num_instances, num_keypoints, 2] [num_instances, num_keypoints, 2]
keypoint_visibilities: rank 2 bool tensor with shape keypoint_visibilities: rank 2 bool tensor with shape
...@@ -1605,6 +1610,12 @@ def _strict_random_crop_image(image, ...@@ -1605,6 +1610,12 @@ def _strict_random_crop_image(image,
0]:im_box_end[0], im_box_begin[1]:im_box_end[1]] 0]:im_box_end[0], im_box_begin[1]:im_box_end[1]]
result.append(new_masks) result.append(new_masks)
if mask_weights is not None:
mask_weights_inside_window = tf.gather(mask_weights, inside_window_ids)
mask_weights_completely_inside_window = tf.gather(
mask_weights_inside_window, keep_ids)
result.append(mask_weights_completely_inside_window)
if keypoints is not None: if keypoints is not None:
keypoints_of_boxes_inside_window = tf.gather(keypoints, inside_window_ids) keypoints_of_boxes_inside_window = tf.gather(keypoints, inside_window_ids)
keypoints_of_boxes_completely_inside_window = tf.gather( keypoints_of_boxes_completely_inside_window = tf.gather(
...@@ -1654,6 +1665,7 @@ def random_crop_image(image, ...@@ -1654,6 +1665,7 @@ def random_crop_image(image,
label_confidences=None, label_confidences=None,
multiclass_scores=None, multiclass_scores=None,
masks=None, masks=None,
mask_weights=None,
keypoints=None, keypoints=None,
keypoint_visibilities=None, keypoint_visibilities=None,
densepose_num_points=None, densepose_num_points=None,
...@@ -1701,6 +1713,8 @@ def random_crop_image(image, ...@@ -1701,6 +1713,8 @@ def random_crop_image(image,
masks: (optional) rank 3 float32 tensor with shape masks: (optional) rank 3 float32 tensor with shape
[num_instances, height, width] containing instance masks. The masks [num_instances, height, width] containing instance masks. The masks
are of the same height, width as the input `image`. are of the same height, width as the input `image`.
mask_weights: (optional) rank 1 float32 tensor with shape [num_instances]
containing weights for each instance mask.
keypoints: (optional) rank 3 float32 tensor with shape keypoints: (optional) rank 3 float32 tensor with shape
[num_instances, num_keypoints, 2]. The keypoints are in y-x [num_instances, num_keypoints, 2]. The keypoints are in y-x
normalized coordinates. normalized coordinates.
...@@ -1751,6 +1765,7 @@ def random_crop_image(image, ...@@ -1751,6 +1765,7 @@ def random_crop_image(image,
[num_instances, num_classes] [num_instances, num_classes]
masks: rank 3 float32 tensor with shape [num_instances, height, width] masks: rank 3 float32 tensor with shape [num_instances, height, width]
containing instance masks. containing instance masks.
mask_weights: rank 1 float32 tensor with shape [num_instances].
keypoints: rank 3 float32 tensor with shape keypoints: rank 3 float32 tensor with shape
[num_instances, num_keypoints, 2] [num_instances, num_keypoints, 2]
keypoint_visibilities: rank 2 bool tensor with shape keypoint_visibilities: rank 2 bool tensor with shape
...@@ -1771,6 +1786,7 @@ def random_crop_image(image, ...@@ -1771,6 +1786,7 @@ def random_crop_image(image,
label_confidences=label_confidences, label_confidences=label_confidences,
multiclass_scores=multiclass_scores, multiclass_scores=multiclass_scores,
masks=masks, masks=masks,
mask_weights=mask_weights,
keypoints=keypoints, keypoints=keypoints,
keypoint_visibilities=keypoint_visibilities, keypoint_visibilities=keypoint_visibilities,
densepose_num_points=densepose_num_points, densepose_num_points=densepose_num_points,
...@@ -1803,6 +1819,8 @@ def random_crop_image(image, ...@@ -1803,6 +1819,8 @@ def random_crop_image(image,
outputs.append(multiclass_scores) outputs.append(multiclass_scores)
if masks is not None: if masks is not None:
outputs.append(masks) outputs.append(masks)
if mask_weights is not None:
outputs.append(mask_weights)
if keypoints is not None: if keypoints is not None:
outputs.append(keypoints) outputs.append(keypoints)
if keypoint_visibilities is not None: if keypoint_visibilities is not None:
...@@ -4388,6 +4406,7 @@ def get_default_func_arg_map(include_label_weights=True, ...@@ -4388,6 +4406,7 @@ def get_default_func_arg_map(include_label_weights=True,
include_label_confidences=False, include_label_confidences=False,
include_multiclass_scores=False, include_multiclass_scores=False,
include_instance_masks=False, include_instance_masks=False,
include_instance_mask_weights=False,
include_keypoints=False, include_keypoints=False,
include_keypoint_visibilities=False, include_keypoint_visibilities=False,
include_dense_pose=False, include_dense_pose=False,
...@@ -4403,6 +4422,8 @@ def get_default_func_arg_map(include_label_weights=True, ...@@ -4403,6 +4422,8 @@ def get_default_func_arg_map(include_label_weights=True,
multiclass scores, too. multiclass scores, too.
include_instance_masks: If True, preprocessing functions will modify the include_instance_masks: If True, preprocessing functions will modify the
instance masks, too. instance masks, too.
include_instance_mask_weights: If True, preprocessing functions will modify
the instance mask weights, too.
include_keypoints: If True, preprocessing functions will modify the include_keypoints: If True, preprocessing functions will modify the
keypoints, too. keypoints, too.
include_keypoint_visibilities: If True, preprocessing functions will modify include_keypoint_visibilities: If True, preprocessing functions will modify
...@@ -4434,6 +4455,11 @@ def get_default_func_arg_map(include_label_weights=True, ...@@ -4434,6 +4455,11 @@ def get_default_func_arg_map(include_label_weights=True,
groundtruth_instance_masks = ( groundtruth_instance_masks = (
fields.InputDataFields.groundtruth_instance_masks) fields.InputDataFields.groundtruth_instance_masks)
groundtruth_instance_mask_weights = None
if include_instance_mask_weights:
groundtruth_instance_mask_weights = (
fields.InputDataFields.groundtruth_instance_mask_weights)
groundtruth_keypoints = None groundtruth_keypoints = None
if include_keypoints: if include_keypoints:
groundtruth_keypoints = fields.InputDataFields.groundtruth_keypoints groundtruth_keypoints = fields.InputDataFields.groundtruth_keypoints
...@@ -4503,7 +4529,8 @@ def get_default_func_arg_map(include_label_weights=True, ...@@ -4503,7 +4529,8 @@ def get_default_func_arg_map(include_label_weights=True,
fields.InputDataFields.groundtruth_boxes, fields.InputDataFields.groundtruth_boxes,
fields.InputDataFields.groundtruth_classes, fields.InputDataFields.groundtruth_classes,
groundtruth_label_weights, groundtruth_label_confidences, groundtruth_label_weights, groundtruth_label_confidences,
multiclass_scores, groundtruth_instance_masks, groundtruth_keypoints, multiclass_scores, groundtruth_instance_masks,
groundtruth_instance_mask_weights, groundtruth_keypoints,
groundtruth_keypoint_visibilities, groundtruth_dp_num_points, groundtruth_keypoint_visibilities, groundtruth_dp_num_points,
groundtruth_dp_part_ids, groundtruth_dp_surface_coords), groundtruth_dp_part_ids, groundtruth_dp_surface_coords),
random_pad_image: random_pad_image:
......
...@@ -1894,6 +1894,37 @@ class PreprocessorTest(test_case.TestCase, parameterized.TestCase): ...@@ -1894,6 +1894,37 @@ class PreprocessorTest(test_case.TestCase, parameterized.TestCase):
self.assertAllClose( self.assertAllClose(
new_boxes.flatten(), expected_boxes.flatten()) new_boxes.flatten(), expected_boxes.flatten())
def testStrictRandomCropImageWithMaskWeights(self):
  """Strict random crop returns mask weights alongside the cropped masks.

  `tf.image.sample_distorted_bounding_box` is patched to return a fixed crop
  window (begin [6, 143, 0], size [190, 237, -1]) so the cropped image, masks
  and boxes are deterministic. For this window both instances are retained
  (two boxes and two masks are expected), so the mask weights are expected
  back as [1.0, 0.0].
  """
  # Fixed return value for the patched sampler: (begin, size, bboxes).
  fixed_crop_window = (
      tf.constant([6, 143, 0], dtype=tf.int32),
      tf.constant([190, 237, -1], dtype=tf.int32),
      tf.constant([[[0.03, 0.3575, 0.98, 0.95]]], dtype=tf.float32))

  def graph_fn():
    image = self.createColorfulTestImage()[0]
    boxes = self.createTestBoxes()
    labels = self.createTestLabels()
    weights = self.createTestGroundtruthWeights()
    masks = tf.random_uniform([2, 200, 400], dtype=tf.float32)
    mask_weights = tf.constant([1.0, 0.0], dtype=tf.float32)
    with mock.patch.object(
        tf.image,
        'sample_distorted_bounding_box',
        return_value=fixed_crop_window):
      # The crop must be built while the sampler is still patched.
      return preprocessor._strict_random_crop_image(
          image, boxes, labels, weights, masks=masks,
          mask_weights=mask_weights)

  # Outputs are image, boxes, labels, label weights, masks, mask weights;
  # labels and label weights are not inspected by this test.
  (cropped_image, cropped_boxes, _, _,
   cropped_masks, cropped_mask_weights) = self.execute_cpu(graph_fn, [])

  expected_boxes = np.array(
      [[0.0, 0.0, 0.75789469, 1.0],
       [0.23157893, 0.24050637, 0.75789469, 1.0]], dtype=np.float32)
  self.assertAllEqual(cropped_image.shape, [190, 237, 3])
  self.assertAllEqual(cropped_masks.shape, [2, 190, 237])
  self.assertAllClose(cropped_mask_weights, [1.0, 0.0])
  self.assertAllClose(
      cropped_boxes.flatten(), expected_boxes.flatten())
def testStrictRandomCropImageWithKeypoints(self): def testStrictRandomCropImageWithKeypoints(self):
def graph_fn(): def graph_fn():
image = self.createColorfulTestImage()[0] image = self.createColorfulTestImage()[0]
...@@ -1947,6 +1978,7 @@ class PreprocessorTest(test_case.TestCase, parameterized.TestCase): ...@@ -1947,6 +1978,7 @@ class PreprocessorTest(test_case.TestCase, parameterized.TestCase):
labels = self.createTestLabels() labels = self.createTestLabels()
weights = self.createTestGroundtruthWeights() weights = self.createTestGroundtruthWeights()
masks = tf.random_uniform([2, 200, 400], dtype=tf.float32) masks = tf.random_uniform([2, 200, 400], dtype=tf.float32)
mask_weights = tf.constant([1.0, 0.0], dtype=tf.float32)
tensor_dict = { tensor_dict = {
fields.InputDataFields.image: image, fields.InputDataFields.image: image,
...@@ -1954,10 +1986,12 @@ class PreprocessorTest(test_case.TestCase, parameterized.TestCase): ...@@ -1954,10 +1986,12 @@ class PreprocessorTest(test_case.TestCase, parameterized.TestCase):
fields.InputDataFields.groundtruth_classes: labels, fields.InputDataFields.groundtruth_classes: labels,
fields.InputDataFields.groundtruth_weights: weights, fields.InputDataFields.groundtruth_weights: weights,
fields.InputDataFields.groundtruth_instance_masks: masks, fields.InputDataFields.groundtruth_instance_masks: masks,
fields.InputDataFields.groundtruth_instance_mask_weights:
mask_weights
} }
preprocessor_arg_map = preprocessor.get_default_func_arg_map( preprocessor_arg_map = preprocessor.get_default_func_arg_map(
include_instance_masks=True) include_instance_masks=True, include_instance_mask_weights=True)
preprocessing_options = [(preprocessor.random_crop_image, {})] preprocessing_options = [(preprocessor.random_crop_image, {})]
...@@ -1980,16 +2014,19 @@ class PreprocessorTest(test_case.TestCase, parameterized.TestCase): ...@@ -1980,16 +2014,19 @@ class PreprocessorTest(test_case.TestCase, parameterized.TestCase):
fields.InputDataFields.groundtruth_classes] fields.InputDataFields.groundtruth_classes]
distorted_masks = distorted_tensor_dict[ distorted_masks = distorted_tensor_dict[
fields.InputDataFields.groundtruth_instance_masks] fields.InputDataFields.groundtruth_instance_masks]
distorted_mask_weights = distorted_tensor_dict[
fields.InputDataFields.groundtruth_instance_mask_weights]
return [distorted_image, distorted_boxes, distorted_labels, return [distorted_image, distorted_boxes, distorted_labels,
distorted_masks] distorted_masks, distorted_mask_weights]
(distorted_image_, distorted_boxes_, distorted_labels_, (distorted_image_, distorted_boxes_, distorted_labels_,
distorted_masks_) = self.execute_cpu(graph_fn, []) distorted_masks_, distorted_mask_weights_) = self.execute_cpu(graph_fn, [])
expected_boxes = np.array([ expected_boxes = np.array([
[0.0, 0.0, 0.75789469, 1.0], [0.0, 0.0, 0.75789469, 1.0],
[0.23157893, 0.24050637, 0.75789469, 1.0], [0.23157893, 0.24050637, 0.75789469, 1.0],
], dtype=np.float32) ], dtype=np.float32)
self.assertAllEqual(distorted_image_.shape, [1, 190, 237, 3]) self.assertAllEqual(distorted_image_.shape, [1, 190, 237, 3])
self.assertAllEqual(distorted_masks_.shape, [2, 190, 237]) self.assertAllEqual(distorted_masks_.shape, [2, 190, 237])
self.assertAllClose(distorted_mask_weights_, [1.0, 0.0])
self.assertAllEqual(distorted_labels_, [1, 2]) self.assertAllEqual(distorted_labels_, [1, 2])
self.assertAllClose( self.assertAllClose(
distorted_boxes_.flatten(), expected_boxes.flatten()) distorted_boxes_.flatten(), expected_boxes.flatten())
......
...@@ -64,6 +64,7 @@ class InputDataFields(object): ...@@ -64,6 +64,7 @@ class InputDataFields(object):
proposal_boxes: coordinates of object proposal boxes. proposal_boxes: coordinates of object proposal boxes.
proposal_objectness: objectness score of each proposal. proposal_objectness: objectness score of each proposal.
groundtruth_instance_masks: ground truth instance masks. groundtruth_instance_masks: ground truth instance masks.
groundtruth_instance_mask_weights: ground truth instance mask weights.
groundtruth_instance_boundaries: ground truth instance boundaries. groundtruth_instance_boundaries: ground truth instance boundaries.
groundtruth_instance_classes: instance mask-level class labels. groundtruth_instance_classes: instance mask-level class labels.
groundtruth_keypoints: ground truth keypoints. groundtruth_keypoints: ground truth keypoints.
...@@ -122,6 +123,7 @@ class InputDataFields(object): ...@@ -122,6 +123,7 @@ class InputDataFields(object):
proposal_boxes = 'proposal_boxes' proposal_boxes = 'proposal_boxes'
proposal_objectness = 'proposal_objectness' proposal_objectness = 'proposal_objectness'
groundtruth_instance_masks = 'groundtruth_instance_masks' groundtruth_instance_masks = 'groundtruth_instance_masks'
groundtruth_instance_mask_weights = 'groundtruth_instance_mask_weights'
groundtruth_instance_boundaries = 'groundtruth_instance_boundaries' groundtruth_instance_boundaries = 'groundtruth_instance_boundaries'
groundtruth_instance_classes = 'groundtruth_instance_classes' groundtruth_instance_classes = 'groundtruth_instance_classes'
groundtruth_keypoints = 'groundtruth_keypoints' groundtruth_keypoints = 'groundtruth_keypoints'
......
...@@ -373,6 +373,11 @@ class TfExampleDecoder(data_decoder.DataDecoder): ...@@ -373,6 +373,11 @@ class TfExampleDecoder(data_decoder.DataDecoder):
self._decode_png_instance_masks)) self._decode_png_instance_masks))
else: else:
raise ValueError('Did not recognize the `instance_mask_type` option.') raise ValueError('Did not recognize the `instance_mask_type` option.')
self.keys_to_features['image/object/mask/weight'] = (
tf.VarLenFeature(tf.float32))
self.items_to_handlers[
fields.InputDataFields.groundtruth_instance_mask_weights] = (
slim_example_decoder.Tensor('image/object/mask/weight'))
if load_dense_pose: if load_dense_pose:
self.keys_to_features['image/object/densepose/num'] = ( self.keys_to_features['image/object/densepose/num'] = (
tf.VarLenFeature(tf.int64)) tf.VarLenFeature(tf.int64))
...@@ -491,6 +496,10 @@ class TfExampleDecoder(data_decoder.DataDecoder): ...@@ -491,6 +496,10 @@ class TfExampleDecoder(data_decoder.DataDecoder):
tensor of shape [None, num_keypoints] containing keypoint visibilities. tensor of shape [None, num_keypoints] containing keypoint visibilities.
fields.InputDataFields.groundtruth_instance_masks - 3D float32 tensor of fields.InputDataFields.groundtruth_instance_masks - 3D float32 tensor of
shape [None, None, None] containing instance masks. shape [None, None, None] containing instance masks.
fields.InputDataFields.groundtruth_instance_mask_weights - 1D float32
tensor of shape [None] containing weights. These are typically values
in {0.0, 1.0} which indicate whether to consider the mask related to an
object.
fields.InputDataFields.groundtruth_image_classes - 1D int64 of shape fields.InputDataFields.groundtruth_image_classes - 1D int64 of shape
[None] containing classes for the boxes. [None] containing classes for the boxes.
fields.InputDataFields.multiclass_scores - 1D float32 tensor of shape fields.InputDataFields.multiclass_scores - 1D float32 tensor of shape
...@@ -531,6 +540,21 @@ class TfExampleDecoder(data_decoder.DataDecoder): ...@@ -531,6 +540,21 @@ class TfExampleDecoder(data_decoder.DataDecoder):
0), lambda: tensor_dict[fields.InputDataFields.groundtruth_weights], 0), lambda: tensor_dict[fields.InputDataFields.groundtruth_weights],
default_groundtruth_weights) default_groundtruth_weights)
if fields.InputDataFields.groundtruth_instance_masks in tensor_dict:
gt_instance_masks = tensor_dict[
fields.InputDataFields.groundtruth_instance_masks]
num_gt_instance_masks = tf.shape(gt_instance_masks)[0]
gt_instance_mask_weights = tensor_dict[
fields.InputDataFields.groundtruth_instance_mask_weights]
num_gt_instance_mask_weights = tf.shape(gt_instance_mask_weights)[0]
def default_groundtruth_instance_mask_weights():
return tf.ones([num_gt_instance_masks], dtype=tf.float32)
tensor_dict[fields.InputDataFields.groundtruth_instance_mask_weights] = (
tf.cond(tf.greater(num_gt_instance_mask_weights, 0),
lambda: gt_instance_mask_weights,
default_groundtruth_instance_mask_weights))
if fields.InputDataFields.groundtruth_keypoints in tensor_dict: if fields.InputDataFields.groundtruth_keypoints in tensor_dict:
# Set all keypoints that are not labeled to NaN. # Set all keypoints that are not labeled to NaN.
gt_kpt_fld = fields.InputDataFields.groundtruth_keypoints gt_kpt_fld = fields.InputDataFields.groundtruth_keypoints
......
...@@ -1225,6 +1225,9 @@ class TfExampleDecoderTest(test_case.TestCase): ...@@ -1225,6 +1225,9 @@ class TfExampleDecoderTest(test_case.TestCase):
self.assertAllEqual( self.assertAllEqual(
instance_masks.astype(np.float32), instance_masks.astype(np.float32),
tensor_dict[fields.InputDataFields.groundtruth_instance_masks]) tensor_dict[fields.InputDataFields.groundtruth_instance_masks])
self.assertAllEqual(
tensor_dict[fields.InputDataFields.groundtruth_instance_mask_weights],
[1, 1, 1, 1])
self.assertAllEqual(object_classes, self.assertAllEqual(object_classes,
tensor_dict[fields.InputDataFields.groundtruth_classes]) tensor_dict[fields.InputDataFields.groundtruth_classes])
...@@ -1272,6 +1275,71 @@ class TfExampleDecoderTest(test_case.TestCase): ...@@ -1272,6 +1275,71 @@ class TfExampleDecoderTest(test_case.TestCase):
self.assertNotIn(fields.InputDataFields.groundtruth_instance_masks, self.assertNotIn(fields.InputDataFields.groundtruth_instance_masks,
tensor_dict) tensor_dict)
def testDecodeInstanceSegmentationWithWeights(self):
  """Decodes instance masks together with explicit per-instance mask weights.

  Builds a `tf.train.Example` holding 4 flattened 5x3 float masks, the weights
  [1, 1, 0, 1] under 'image/object/mask/weight' and one class label per
  instance, then decodes it with `load_instance_masks=True`.

  Static shapes are checked while the graph is built; values (masks, mask
  weights, classes) are checked on the outputs returned by `execute_cpu`.
  """
  num_instances = 4
  image_height = 5
  image_width = 3

  # Randomly generate image.
  image_tensor = np.random.randint(
      256, size=(image_height, image_width, 3)).astype(np.uint8)
  encoded_jpeg, _ = self._create_encoded_and_decoded_data(
      image_tensor, 'jpeg')

  # Randomly generate instance segmentation masks.
  instance_masks = (
      np.random.randint(2, size=(num_instances, image_height,
                                 image_width)).astype(np.float32))
  instance_masks_flattened = np.reshape(instance_masks, [-1])
  instance_mask_weights = np.array([1, 1, 0, 1], dtype=np.float32)

  # Randomly generate class labels for each instance.
  object_classes = np.random.randint(
      100, size=(num_instances)).astype(np.int64)

  def graph_fn():
    example = tf.train.Example(
        features=tf.train.Features(
            feature={
                'image/encoded':
                    dataset_util.bytes_feature(encoded_jpeg),
                'image/format':
                    dataset_util.bytes_feature(six.b('jpeg')),
                'image/height':
                    dataset_util.int64_feature(image_height),
                'image/width':
                    dataset_util.int64_feature(image_width),
                'image/object/mask':
                    dataset_util.float_list_feature(instance_masks_flattened),
                'image/object/mask/weight':
                    dataset_util.float_list_feature(instance_mask_weights),
                'image/object/class/label':
                    dataset_util.int64_list_feature(object_classes)
            })).SerializeToString()
    example_decoder = tf_example_decoder.TfExampleDecoder(
        load_instance_masks=True)
    output = example_decoder.decode(tf.convert_to_tensor(example))

    # Only static shape information is asserted during graph construction;
    # tensor values are verified on the evaluated outputs below.
    self.assertAllEqual(
        (output[fields.InputDataFields.groundtruth_instance_masks].get_shape(
        ).as_list()), [4, 5, 3])
    self.assertAllEqual((output[
        fields.InputDataFields.groundtruth_classes].get_shape().as_list()),
                        [4])
    return output

  tensor_dict = self.execute_cpu(graph_fn, [])

  self.assertAllEqual(
      instance_masks.astype(np.float32),
      tensor_dict[fields.InputDataFields.groundtruth_instance_masks])
  # The weights supplied in the example must be returned unchanged, not
  # replaced by the all-ones default used when the feature is absent.
  self.assertAllEqual(
      instance_mask_weights,
      tensor_dict[fields.InputDataFields.groundtruth_instance_mask_weights])
  self.assertAllEqual(object_classes,
                      tensor_dict[fields.InputDataFields.groundtruth_classes])
def testDecodeImageLabels(self): def testDecodeImageLabels(self):
image_tensor = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8) image_tensor = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
encoded_jpeg, _ = self._create_encoded_and_decoded_data( encoded_jpeg, _ = self._create_encoded_and_decoded_data(
......
...@@ -479,6 +479,7 @@ def pad_input_data_to_static_shapes(tensor_dict, ...@@ -479,6 +479,7 @@ def pad_input_data_to_static_shapes(tensor_dict,
input_fields.groundtruth_instance_masks: [ input_fields.groundtruth_instance_masks: [
max_num_boxes, height, width max_num_boxes, height, width
], ],
input_fields.groundtruth_instance_mask_weights: [max_num_boxes],
input_fields.groundtruth_is_crowd: [max_num_boxes], input_fields.groundtruth_is_crowd: [max_num_boxes],
input_fields.groundtruth_group_of: [max_num_boxes], input_fields.groundtruth_group_of: [max_num_boxes],
input_fields.groundtruth_area: [max_num_boxes], input_fields.groundtruth_area: [max_num_boxes],
...@@ -601,6 +602,8 @@ def augment_input_data(tensor_dict, data_augmentation_options): ...@@ -601,6 +602,8 @@ def augment_input_data(tensor_dict, data_augmentation_options):
include_instance_masks = (fields.InputDataFields.groundtruth_instance_masks include_instance_masks = (fields.InputDataFields.groundtruth_instance_masks
in tensor_dict) in tensor_dict)
include_instance_mask_weights = (
fields.InputDataFields.groundtruth_instance_mask_weights in tensor_dict)
include_keypoints = (fields.InputDataFields.groundtruth_keypoints include_keypoints = (fields.InputDataFields.groundtruth_keypoints
in tensor_dict) in tensor_dict)
include_keypoint_visibilities = ( include_keypoint_visibilities = (
...@@ -624,6 +627,7 @@ def augment_input_data(tensor_dict, data_augmentation_options): ...@@ -624,6 +627,7 @@ def augment_input_data(tensor_dict, data_augmentation_options):
include_label_confidences=include_label_confidences, include_label_confidences=include_label_confidences,
include_multiclass_scores=include_multiclass_scores, include_multiclass_scores=include_multiclass_scores,
include_instance_masks=include_instance_masks, include_instance_masks=include_instance_masks,
include_instance_mask_weights=include_instance_mask_weights,
include_keypoints=include_keypoints, include_keypoints=include_keypoints,
include_keypoint_visibilities=include_keypoint_visibilities, include_keypoint_visibilities=include_keypoint_visibilities,
include_dense_pose=include_dense_pose, include_dense_pose=include_dense_pose,
...@@ -652,6 +656,7 @@ def _get_labels_dict(input_dict): ...@@ -652,6 +656,7 @@ def _get_labels_dict(input_dict):
fields.InputDataFields.groundtruth_keypoint_depths, fields.InputDataFields.groundtruth_keypoint_depths,
fields.InputDataFields.groundtruth_keypoint_depth_weights, fields.InputDataFields.groundtruth_keypoint_depth_weights,
fields.InputDataFields.groundtruth_instance_masks, fields.InputDataFields.groundtruth_instance_masks,
fields.InputDataFields.groundtruth_instance_mask_weights,
fields.InputDataFields.groundtruth_area, fields.InputDataFields.groundtruth_area,
fields.InputDataFields.groundtruth_is_crowd, fields.InputDataFields.groundtruth_is_crowd,
fields.InputDataFields.groundtruth_group_of, fields.InputDataFields.groundtruth_group_of,
...@@ -804,6 +809,9 @@ def train_input(train_config, train_input_config, ...@@ -804,6 +809,9 @@ def train_input(train_config, train_input_config,
labels[fields.InputDataFields.groundtruth_instance_masks] is a labels[fields.InputDataFields.groundtruth_instance_masks] is a
[batch_size, num_boxes, H, W] float32 tensor containing only binary [batch_size, num_boxes, H, W] float32 tensor containing only binary
values, which represent instance masks for objects. values, which represent instance masks for objects.
labels[fields.InputDataFields.groundtruth_instance_mask_weights] is a
[batch_size, num_boxes] float32 tensor containing groundtruth weights
for each instance mask.
labels[fields.InputDataFields.groundtruth_keypoints] is a labels[fields.InputDataFields.groundtruth_keypoints] is a
[batch_size, num_boxes, num_keypoints, 2] float32 tensor containing [batch_size, num_boxes, num_keypoints, 2] float32 tensor containing
keypoints for each box. keypoints for each box.
...@@ -961,6 +969,9 @@ def eval_input(eval_config, eval_input_config, model_config, ...@@ -961,6 +969,9 @@ def eval_input(eval_config, eval_input_config, model_config,
labels[fields.InputDataFields.groundtruth_instance_masks] is a labels[fields.InputDataFields.groundtruth_instance_masks] is a
[1, num_boxes, H, W] float32 tensor containing only binary values, [1, num_boxes, H, W] float32 tensor containing only binary values,
which represent instance masks for objects. which represent instance masks for objects.
labels[fields.InputDataFields.groundtruth_instance_mask_weights] is a
[1, num_boxes] float32 tensor containing groundtruth weights for each
instance mask.
labels[fields.InputDataFields.groundtruth_weights] is a labels[fields.InputDataFields.groundtruth_weights] is a
[batch_size, num_boxes, num_keypoints] float32 tensor containing [batch_size, num_boxes, num_keypoints] float32 tensor containing
groundtruth weights for the keypoints. groundtruth weights for the keypoints.
......
...@@ -795,15 +795,20 @@ class DataAugmentationFnTest(test_case.TestCase): ...@@ -795,15 +795,20 @@ class DataAugmentationFnTest(test_case.TestCase):
fields.InputDataFields.image: fields.InputDataFields.image:
tf.constant(np.random.rand(10, 10, 3).astype(np.float32)), tf.constant(np.random.rand(10, 10, 3).astype(np.float32)),
fields.InputDataFields.groundtruth_instance_masks: fields.InputDataFields.groundtruth_instance_masks:
tf.constant(np.zeros([2, 10, 10], np.uint8)) tf.constant(np.zeros([2, 10, 10], np.uint8)),
fields.InputDataFields.groundtruth_instance_mask_weights:
tf.constant([1.0, 0.0], np.float32)
} }
augmented_tensor_dict = data_augmentation_fn(tensor_dict=tensor_dict) augmented_tensor_dict = data_augmentation_fn(tensor_dict=tensor_dict)
return (augmented_tensor_dict[fields.InputDataFields.image], return (augmented_tensor_dict[fields.InputDataFields.image],
augmented_tensor_dict[fields.InputDataFields. augmented_tensor_dict[fields.InputDataFields.
groundtruth_instance_masks]) groundtruth_instance_masks],
image, masks = self.execute_cpu(graph_fn, []) augmented_tensor_dict[fields.InputDataFields.
groundtruth_instance_mask_weights])
image, masks, mask_weights = self.execute_cpu(graph_fn, [])
self.assertAllEqual(image.shape, [20, 20, 3]) self.assertAllEqual(image.shape, [20, 20, 3])
self.assertAllEqual(masks.shape, [2, 20, 20]) self.assertAllEqual(masks.shape, [2, 20, 20])
self.assertAllClose(mask_weights, [1.0, 0.0])
def test_include_keypoints_in_data_augmentation(self): def test_include_keypoints_in_data_augmentation(self):
data_augmentation_options = [ data_augmentation_options = [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment