Commit 624ac92d authored by Yu-hui Chen, committed by TF Object Detection Team

Updated the tf_example decoder and decoder builders to parse the keypoint depth features from the input tf.Example.

PiperOrigin-RevId: 353305509
parent ca9cf75f
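In short: when `load_keypoint_depth_features` is enabled on the `InputReader`, the decoder reads the flat `image/object/keypoint/z` and `image/object/keypoint/z/weights` float lists from the tf.Example and exposes them as `[num_instances, num_keypoints]` tensors under `groundtruth_keypoint_depths` and `groundtruth_keypoint_depth_weights`, falling back to zeros when the features are absent. A minimal end-to-end sketch, mirroring the tests in this change (the feature values and box coordinates are illustrative, not part of the commit):

import numpy as np
import tensorflow.compat.v1 as tf

from object_detection.core import standard_fields as fields
from object_detection.data_decoders import tf_example_decoder
from object_detection.utils import dataset_util

tf.disable_eager_execution()  # run the decoder in TF1-style graph mode

# Encode a tiny random image so the example is self-contained.
image = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
with tf.Session() as sess:
  encoded_jpeg = sess.run(tf.image.encode_jpeg(tf.constant(image)))

# Two instances x three keypoints, serialized as flat float lists.
example = tf.train.Example(
    features=tf.train.Features(
        feature={
            'image/encoded': dataset_util.bytes_feature(encoded_jpeg),
            'image/format': dataset_util.bytes_feature(b'jpeg'),
            'image/object/bbox/ymin': dataset_util.float_list_feature([0.0, 0.5]),
            'image/object/bbox/xmin': dataset_util.float_list_feature([0.0, 0.5]),
            'image/object/bbox/ymax': dataset_util.float_list_feature([0.5, 1.0]),
            'image/object/bbox/xmax': dataset_util.float_list_feature([0.5, 1.0]),
            'image/object/keypoint/y': dataset_util.float_list_feature(
                [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]),
            'image/object/keypoint/x': dataset_util.float_list_feature(
                [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]),
            'image/object/keypoint/visibility':
                dataset_util.int64_list_feature([2] * 6),
            # The new fields: one relative depth and one weight per keypoint.
            'image/object/keypoint/z': dataset_util.float_list_feature(
                [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]),
            'image/object/keypoint/z/weights':
                dataset_util.float_list_feature([1.0] * 6),
        })).SerializeToString()

decoder = tf_example_decoder.TfExampleDecoder(
    num_keypoints=3, load_keypoint_depth_features=True)
tensor_dict = decoder.decode(tf.convert_to_tensor(example))

with tf.Session() as sess:
  out = sess.run(tensor_dict)
# Both tensors come back shaped [2, 3] (num_instances, num_keypoints).
print(out[fields.InputDataFields.groundtruth_keypoint_depths])
print(out[fields.InputDataFields.groundtruth_keypoint_depth_weights])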
research/object_detection/builders/decoder_builder.py

@@ -60,7 +60,9 @@ def build(input_reader_config):
         num_keypoints=input_reader_config.num_keypoints,
         expand_hierarchy_labels=input_reader_config.expand_labels_hierarchy,
         load_dense_pose=input_reader_config.load_dense_pose,
-        load_track_id=input_reader_config.load_track_id)
+        load_track_id=input_reader_config.load_track_id,
+        load_keypoint_depth_features=input_reader_config
+        .load_keypoint_depth_features)
     return decoder
   elif input_type == input_reader_pb2.InputType.Value('TF_SEQUENCE_EXAMPLE'):
     decoder = tf_sequence_example_decoder.TfSequenceExampleDecoder(
research/object_detection/builders/decoder_builder_test.py

@@ -65,6 +65,8 @@ class DecoderBuilderTest(test_case.TestCase):
         'image/object/bbox/ymax': dataset_util.float_list_feature([1.0]),
         'image/object/class/label': dataset_util.int64_list_feature([2]),
         'image/object/mask': dataset_util.float_list_feature(flat_mask),
+        'image/object/keypoint/x': dataset_util.float_list_feature([1.0, 1.0]),
+        'image/object/keypoint/y': dataset_util.float_list_feature([1.0, 1.0])
     }
     if has_additional_channels:
       additional_channels_key = 'image/additional_channels/encoded'
@@ -188,6 +190,28 @@ class DecoderBuilderTest(test_case.TestCase):
     masks = self.execute_cpu(graph_fn, [])
     self.assertAllEqual((1, 4, 5), masks.shape)

+  def test_build_tf_record_input_reader_and_load_keypoint_depth(self):
+    input_reader_text_proto = """
+      load_keypoint_depth_features: true
+      num_keypoints: 2
+      tf_record_input_reader {}
+    """
+    input_reader_proto = input_reader_pb2.InputReader()
+    text_format.Parse(input_reader_text_proto, input_reader_proto)
+    decoder = decoder_builder.build(input_reader_proto)
+    serialized_example = self._make_serialized_tf_example()
+
+    def graph_fn():
+      tensor_dict = decoder.decode(serialized_example)
+      return (tensor_dict[fields.InputDataFields.groundtruth_keypoint_depths],
+              tensor_dict[
+                  fields.InputDataFields.groundtruth_keypoint_depth_weights])
+
+    (kpts_depths, kpts_depth_weights) = self.execute_cpu(graph_fn, [])
+    self.assertAllEqual((1, 2), kpts_depths.shape)
+    self.assertAllEqual((1, 2), kpts_depth_weights.shape)
+
 if __name__ == '__main__':
   tf.test.main()
research/object_detection/core/standard_fields.py

@@ -67,6 +67,9 @@ class InputDataFields(object):
     groundtruth_instance_boundaries: ground truth instance boundaries.
     groundtruth_instance_classes: instance mask-level class labels.
     groundtruth_keypoints: ground truth keypoints.
+    groundtruth_keypoint_depths: Relative depth of the keypoints.
+    groundtruth_keypoint_depth_weights: Weights of the relative depth of the
+      keypoints.
     groundtruth_keypoint_visibilities: ground truth keypoint visibilities.
     groundtruth_keypoint_weights: groundtruth weight factor for keypoints.
     groundtruth_label_weights: groundtruth label weights.
@@ -122,6 +125,8 @@ class InputDataFields(object):
   groundtruth_instance_boundaries = 'groundtruth_instance_boundaries'
   groundtruth_instance_classes = 'groundtruth_instance_classes'
   groundtruth_keypoints = 'groundtruth_keypoints'
+  groundtruth_keypoint_depths = 'groundtruth_keypoint_depths'
+  groundtruth_keypoint_depth_weights = 'groundtruth_keypoint_depth_weights'
   groundtruth_keypoint_visibilities = 'groundtruth_keypoint_visibilities'
   groundtruth_keypoint_weights = 'groundtruth_keypoint_weights'
   groundtruth_label_weights = 'groundtruth_label_weights'
@@ -162,6 +167,7 @@ class DetectionResultFields(object):
     detection_boundaries: contains an object boundary for each detection box.
     detection_keypoints: contains detection keypoints for each detection box.
     detection_keypoint_scores: contains detection keypoint scores.
+    detection_keypoint_depths: contains detection keypoint depths.
     num_detections: number of detections in the batch.
     raw_detection_boxes: contains decoded detection boxes without Non-Max
       suppression.
@@ -183,6 +189,7 @@ class DetectionResultFields(object):
   detection_boundaries = 'detection_boundaries'
   detection_keypoints = 'detection_keypoints'
   detection_keypoint_scores = 'detection_keypoint_scores'
+  detection_keypoint_depths = 'detection_keypoint_depths'
   detection_embeddings = 'detection_embeddings'
   detection_offsets = 'detection_temporal_offsets'
   num_detections = 'num_detections'
@@ -205,6 +212,8 @@ class BoxListFields(object):
     keypoints: keypoints per bounding box.
     keypoint_visibilities: keypoint visibilities per bounding box.
     keypoint_heatmaps: keypoint heatmaps per bounding box.
+    keypoint_depths: keypoint depths per bounding box.
+    keypoint_depth_weights: keypoint depth weights per bounding box.
     densepose_num_points: number of DensePose points per bounding box.
     densepose_part_ids: DensePose part ids per bounding box.
     densepose_surface_coords: DensePose surface coordinates per bounding box.
@@ -223,6 +232,8 @@ class BoxListFields(object):
   keypoints = 'keypoints'
   keypoint_visibilities = 'keypoint_visibilities'
   keypoint_heatmaps = 'keypoint_heatmaps'
+  keypoint_depths = 'keypoint_depths'
+  keypoint_depth_weights = 'keypoint_depth_weights'
   densepose_num_points = 'densepose_num_points'
   densepose_part_ids = 'densepose_part_ids'
   densepose_surface_coords = 'densepose_surface_coords'
research/object_detection/data_decoders/tf_example_decoder.py

@@ -139,7 +139,8 @@ class TfExampleDecoder(data_decoder.DataDecoder):
                load_context_features=False,
                expand_hierarchy_labels=False,
                load_dense_pose=False,
-               load_track_id=False):
+               load_track_id=False,
+               load_keypoint_depth_features=False):
     """Constructor sets keys_to_features and items_to_handlers.

     Args:
@@ -172,6 +173,10 @@ class TfExampleDecoder(data_decoder.DataDecoder):
        the labels are expanded to descendants.
      load_dense_pose: Whether to load DensePose annotations.
      load_track_id: Whether to load tracking annotations.
+     load_keypoint_depth_features: Whether to load the keypoint depth features
+       including keypoint relative depths and weights. If this field is set to
+       True but no keypoint depth features are in the input tf.Example, then
+       default values will be populated.

     Raises:
       ValueError: If `instance_mask_type` option is not one of
@@ -180,6 +185,7 @@ class TfExampleDecoder(data_decoder.DataDecoder):
       ValueError: If `expand_labels_hierarchy` is True, but the
         `label_map_proto_file` is not provided.
     """
+
     # TODO(rathodv): delete unused `use_display_name` argument once we change
     # other decoders to handle label maps similarly.
     del use_display_name
@@ -331,6 +337,23 @@ class TfExampleDecoder(data_decoder.DataDecoder):
             slim_example_decoder.ItemHandlerCallback(
                 ['image/object/keypoint/x', 'image/object/keypoint/visibility'],
                 self._reshape_keypoint_visibilities))
+    if load_keypoint_depth_features:
+      self.keys_to_features['image/object/keypoint/z'] = (
+          tf.VarLenFeature(tf.float32))
+      self.keys_to_features['image/object/keypoint/z/weights'] = (
+          tf.VarLenFeature(tf.float32))
+      self.items_to_handlers[
+          fields.InputDataFields.groundtruth_keypoint_depths] = (
+              slim_example_decoder.ItemHandlerCallback(
+                  ['image/object/keypoint/x', 'image/object/keypoint/z'],
+                  self._reshape_keypoint_depths))
+      self.items_to_handlers[
+          fields.InputDataFields.groundtruth_keypoint_depth_weights] = (
+              slim_example_decoder.ItemHandlerCallback(
+                  ['image/object/keypoint/x',
+                   'image/object/keypoint/z/weights'],
+                  self._reshape_keypoint_depth_weights))
     if load_instance_masks:
       if instance_mask_type in (input_reader_pb2.DEFAULT,
                                 input_reader_pb2.NUMERICAL_MASKS):
@@ -601,6 +624,73 @@ class TfExampleDecoder(data_decoder.DataDecoder):
     keypoints = tf.reshape(keypoints, [-1, self._num_keypoints, 2])
     return keypoints

+  def _reshape_keypoint_depths(self, keys_to_tensors):
+    """Reshape keypoint depths.
+
+    The keypoint depths are reshaped to [num_instances, num_keypoints]. The
+    keypoint depth tensor is expected to have the same shape as the keypoint x
+    (or y) tensors. If not (usually because the example does not have the depth
+    groundtruth), then default depth values (zero) are provided.
+
+    Args:
+      keys_to_tensors: a dictionary from keys to tensors. Expected keys are:
+        'image/object/keypoint/x'
+        'image/object/keypoint/z'
+
+    Returns:
+      A 2-D float tensor of shape [num_instances, num_keypoints] with values
+      representing the keypoint depths.
+    """
+    x = keys_to_tensors['image/object/keypoint/x']
+    z = keys_to_tensors['image/object/keypoint/z']
+    if isinstance(z, tf.SparseTensor):
+      z = tf.sparse_tensor_to_dense(z)
+    if isinstance(x, tf.SparseTensor):
+      x = tf.sparse_tensor_to_dense(x)
+
+    default_z = tf.zeros_like(x)
+    # Use keypoint depth groundtruth if provided, otherwise use the default
+    # depth value.
+    z = tf.cond(tf.equal(tf.size(x), tf.size(z)),
+                true_fn=lambda: z,
+                false_fn=lambda: default_z)
+    z = tf.reshape(z, [-1, self._num_keypoints])
+    return z
+
+  def _reshape_keypoint_depth_weights(self, keys_to_tensors):
+    """Reshape keypoint depth weights.
+
+    The keypoint depth weights are reshaped to [num_instances, num_keypoints].
+    The keypoint depth weights tensor is expected to have the same shape as the
+    keypoint x (or y) tensors. If not (usually because the example does not have
+    the depth weights groundtruth), then default weight values (zero) are
+    provided.
+
+    Args:
+      keys_to_tensors: a dictionary from keys to tensors. Expected keys are:
+        'image/object/keypoint/x'
+        'image/object/keypoint/z/weights'
+
+    Returns:
+      A 2-D float tensor of shape [num_instances, num_keypoints] with values
+      representing the keypoint depth weights.
+    """
+    x = keys_to_tensors['image/object/keypoint/x']
+    z = keys_to_tensors['image/object/keypoint/z/weights']
+    if isinstance(z, tf.SparseTensor):
+      z = tf.sparse_tensor_to_dense(z)
+    if isinstance(x, tf.SparseTensor):
+      x = tf.sparse_tensor_to_dense(x)
+
+    default_z = tf.zeros_like(x)
+    # Use keypoint depth weights if provided, otherwise use the default
+    # values.
+    z = tf.cond(tf.equal(tf.size(x), tf.size(z)),
+                true_fn=lambda: z,
+                false_fn=lambda: default_z)
+    z = tf.reshape(z, [-1, self._num_keypoints])
+    return z
+
   def _reshape_keypoint_visibilities(self, keys_to_tensors):
     """Reshape keypoint visibilities.
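Both new handlers guard against examples that lack depth groundtruth with the same pattern: the z tensor is trusted only when its flat size matches the keypoint x tensor, otherwise zeros of the same shape are substituted, so downstream code always sees a dense [num_instances, num_keypoints] tensor. A standalone sketch of that guard (the helper name is illustrative, not part of the decoder):

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

def reshape_with_default(x, z, num_keypoints):
  """Reshapes z to [num_instances, num_keypoints], or zeros if z is missing."""
  # Trust z only when there is exactly one value per keypoint coordinate;
  # otherwise substitute a zero tensor with the same flat size as x.
  z = tf.cond(tf.equal(tf.size(x), tf.size(z)),
              true_fn=lambda: z,
              false_fn=lambda: tf.zeros_like(x))
  return tf.reshape(z, [-1, num_keypoints])

x = tf.constant([0.1, 0.2, 0.3, 0.4, 0.5, 0.6])  # 2 instances x 3 keypoints
z_missing = tf.constant([], dtype=tf.float32)    # no depth groundtruth
with tf.Session() as sess:
  print(sess.run(reshape_with_default(x, z_missing, num_keypoints=3)))
  # [[0. 0. 0.]
  #  [0. 0. 0.]]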
research/object_detection/data_decoders/tf_example_decoder_test.py

@@ -275,6 +275,124 @@ class TfExampleDecoderTest(test_case.TestCase):
     self.assertAllEqual(expected_boxes,
                         tensor_dict[fields.InputDataFields.groundtruth_boxes])

+  def testDecodeKeypointDepth(self):
+    image_tensor = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
+    encoded_jpeg, _ = self._create_encoded_and_decoded_data(
+        image_tensor, 'jpeg')
+    bbox_ymins = [0.0, 4.0]
+    bbox_xmins = [1.0, 5.0]
+    bbox_ymaxs = [2.0, 6.0]
+    bbox_xmaxs = [3.0, 7.0]
+    keypoint_ys = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
+    keypoint_xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
+    keypoint_visibility = [1, 2, 0, 1, 0, 2]
+    keypoint_depths = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
+    keypoint_depth_weights = [1.0, 0.9, 0.8, 0.7, 0.6, 0.5]
+
+    def graph_fn():
+      example = tf.train.Example(
+          features=tf.train.Features(
+              feature={
+                  'image/encoded':
+                      dataset_util.bytes_feature(encoded_jpeg),
+                  'image/format':
+                      dataset_util.bytes_feature(six.b('jpeg')),
+                  'image/object/bbox/ymin':
+                      dataset_util.float_list_feature(bbox_ymins),
+                  'image/object/bbox/xmin':
+                      dataset_util.float_list_feature(bbox_xmins),
+                  'image/object/bbox/ymax':
+                      dataset_util.float_list_feature(bbox_ymaxs),
+                  'image/object/bbox/xmax':
+                      dataset_util.float_list_feature(bbox_xmaxs),
+                  'image/object/keypoint/y':
+                      dataset_util.float_list_feature(keypoint_ys),
+                  'image/object/keypoint/x':
+                      dataset_util.float_list_feature(keypoint_xs),
+                  'image/object/keypoint/z':
+                      dataset_util.float_list_feature(keypoint_depths),
+                  'image/object/keypoint/z/weights':
+                      dataset_util.float_list_feature(keypoint_depth_weights),
+                  'image/object/keypoint/visibility':
+                      dataset_util.int64_list_feature(keypoint_visibility),
+              })).SerializeToString()
+
+      example_decoder = tf_example_decoder.TfExampleDecoder(
+          num_keypoints=3, load_keypoint_depth_features=True)
+      output = example_decoder.decode(tf.convert_to_tensor(example))
+
+      self.assertAllEqual(
+          (output[fields.InputDataFields.groundtruth_keypoint_depths]
+           .get_shape().as_list()), [2, 3])
+      self.assertAllEqual(
+          (output[fields.InputDataFields.groundtruth_keypoint_depth_weights]
+           .get_shape().as_list()), [2, 3])
+      return output
+
+    tensor_dict = self.execute_cpu(graph_fn, [])
+
+    expected_keypoint_depths = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
+    self.assertAllClose(
+        expected_keypoint_depths,
+        tensor_dict[fields.InputDataFields.groundtruth_keypoint_depths])
+
+    expected_keypoint_depth_weights = [[1.0, 0.9, 0.8], [0.7, 0.6, 0.5]]
+    self.assertAllClose(
+        expected_keypoint_depth_weights,
+        tensor_dict[fields.InputDataFields.groundtruth_keypoint_depth_weights])
+
+  def testDecodeKeypointDepthNoDepth(self):
+    image_tensor = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
+    encoded_jpeg, _ = self._create_encoded_and_decoded_data(
+        image_tensor, 'jpeg')
+    bbox_ymins = [0.0, 4.0]
+    bbox_xmins = [1.0, 5.0]
+    bbox_ymaxs = [2.0, 6.0]
+    bbox_xmaxs = [3.0, 7.0]
+    keypoint_ys = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
+    keypoint_xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
+    keypoint_visibility = [1, 2, 0, 1, 0, 2]
+
+    def graph_fn():
+      example = tf.train.Example(
+          features=tf.train.Features(
+              feature={
+                  'image/encoded':
+                      dataset_util.bytes_feature(encoded_jpeg),
+                  'image/format':
+                      dataset_util.bytes_feature(six.b('jpeg')),
+                  'image/object/bbox/ymin':
+                      dataset_util.float_list_feature(bbox_ymins),
+                  'image/object/bbox/xmin':
+                      dataset_util.float_list_feature(bbox_xmins),
+                  'image/object/bbox/ymax':
+                      dataset_util.float_list_feature(bbox_ymaxs),
+                  'image/object/bbox/xmax':
+                      dataset_util.float_list_feature(bbox_xmaxs),
+                  'image/object/keypoint/y':
+                      dataset_util.float_list_feature(keypoint_ys),
+                  'image/object/keypoint/x':
+                      dataset_util.float_list_feature(keypoint_xs),
+                  'image/object/keypoint/visibility':
+                      dataset_util.int64_list_feature(keypoint_visibility),
+              })).SerializeToString()
+
+      example_decoder = tf_example_decoder.TfExampleDecoder(
+          num_keypoints=3, load_keypoint_depth_features=True)
+      output = example_decoder.decode(tf.convert_to_tensor(example))
+      return output
+
+    tensor_dict = self.execute_cpu(graph_fn, [])
+
+    expected_keypoints_depth_default = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
+    self.assertAllClose(
+        expected_keypoints_depth_default,
+        tensor_dict[fields.InputDataFields.groundtruth_keypoint_depths])
+    self.assertAllClose(
+        expected_keypoints_depth_default,
+        tensor_dict[fields.InputDataFields.groundtruth_keypoint_depth_weights])
+
   def testDecodeKeypoint(self):
     image_tensor = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
     encoded_jpeg, _ = self._create_encoded_and_decoded_data(
research/object_detection/protos/input_reader.proto

@@ -30,7 +30,7 @@ enum InputType {
   TF_SEQUENCE_EXAMPLE = 2;  // TfSequenceExample Input
 }

-// Next id: 37
+// Next id: 38
 message InputReader {
   // Name of input reader. Typically used to describe the dataset that is read
   // by this input reader.
@@ -134,6 +134,9 @@ message InputReader {
   // Whether to load track information.
   optional bool load_track_id = 33 [default = false];

+  // Whether to load keypoint depth features.
+  optional bool load_keypoint_depth_features = 37 [default = false];
+
   // Whether to use the display name when decoding examples. This is only used
   // when mapping class text strings to integers.
   optional bool use_display_name = 17 [default = false];
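With the new proto field (tag 37), turning on depth decoding from a config is a one-line addition to the input reader, which the builder then forwards to TfExampleDecoder. A hedged sketch of wiring it through decoder_builder, following the builder test above (the keypoint count and record path are placeholders):

from google.protobuf import text_format

from object_detection.builders import decoder_builder
from object_detection.protos import input_reader_pb2

input_reader_proto = input_reader_pb2.InputReader()
text_format.Parse("""
  num_keypoints: 17
  load_keypoint_depth_features: true
  tf_record_input_reader {
    input_path: "PATH_TO_BE_CONFIGURED/train.record-?????-of-00100"
  }
""", input_reader_proto)

# build() now passes load_keypoint_depth_features to TfExampleDecoder, so
# decoded examples include the two groundtruth keypoint depth tensors.
decoder = decoder_builder.build(input_reader_proto)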