Commit 624ac92d authored by Yu-hui Chen's avatar Yu-hui Chen Committed by TF Object Detection Team
Browse files

Updated the tf_example decoder and decoder builders to parse the keypoint...

Updated the tf_example decoder and decoder builders to parse the keypoint depth features from the input tf.Example.

PiperOrigin-RevId: 353305509
parent ca9cf75f
......@@ -60,7 +60,9 @@ def build(input_reader_config):
num_keypoints=input_reader_config.num_keypoints,
expand_hierarchy_labels=input_reader_config.expand_labels_hierarchy,
load_dense_pose=input_reader_config.load_dense_pose,
load_track_id=input_reader_config.load_track_id)
load_track_id=input_reader_config.load_track_id,
load_keypoint_depth_features=input_reader_config
.load_keypoint_depth_features)
return decoder
elif input_type == input_reader_pb2.InputType.Value('TF_SEQUENCE_EXAMPLE'):
decoder = tf_sequence_example_decoder.TfSequenceExampleDecoder(
......
......@@ -65,6 +65,8 @@ class DecoderBuilderTest(test_case.TestCase):
'image/object/bbox/ymax': dataset_util.float_list_feature([1.0]),
'image/object/class/label': dataset_util.int64_list_feature([2]),
'image/object/mask': dataset_util.float_list_feature(flat_mask),
'image/object/keypoint/x': dataset_util.float_list_feature([1.0, 1.0]),
'image/object/keypoint/y': dataset_util.float_list_feature([1.0, 1.0])
}
if has_additional_channels:
additional_channels_key = 'image/additional_channels/encoded'
......@@ -188,6 +190,28 @@ class DecoderBuilderTest(test_case.TestCase):
masks = self.execute_cpu(graph_fn, [])
self.assertAllEqual((1, 4, 5), masks.shape)
def test_build_tf_record_input_reader_and_load_keypoint_depth(self):
input_reader_text_proto = """
load_keypoint_depth_features: true
num_keypoints: 2
tf_record_input_reader {}
"""
input_reader_proto = input_reader_pb2.InputReader()
text_format.Parse(input_reader_text_proto, input_reader_proto)
decoder = decoder_builder.build(input_reader_proto)
serialized_example = self._make_serialized_tf_example()
def graph_fn():
tensor_dict = decoder.decode(serialized_example)
return (tensor_dict[fields.InputDataFields.groundtruth_keypoint_depths],
tensor_dict[
fields.InputDataFields.groundtruth_keypoint_depth_weights])
(kpts_depths, kpts_depth_weights) = self.execute_cpu(graph_fn, [])
self.assertAllEqual((1, 2), kpts_depths.shape)
self.assertAllEqual((1, 2), kpts_depth_weights.shape)
if __name__ == '__main__':
tf.test.main()
......@@ -67,6 +67,9 @@ class InputDataFields(object):
groundtruth_instance_boundaries: ground truth instance boundaries.
groundtruth_instance_classes: instance mask-level class labels.
groundtruth_keypoints: ground truth keypoints.
groundtruth_keypoint_depths: Relative depth of the keypoints.
groundtruth_keypoint_depth_weights: Weights of the relative depth of the
keypoints.
groundtruth_keypoint_visibilities: ground truth keypoint visibilities.
groundtruth_keypoint_weights: groundtruth weight factor for keypoints.
groundtruth_label_weights: groundtruth label weights.
......@@ -122,6 +125,8 @@ class InputDataFields(object):
groundtruth_instance_boundaries = 'groundtruth_instance_boundaries'
groundtruth_instance_classes = 'groundtruth_instance_classes'
groundtruth_keypoints = 'groundtruth_keypoints'
groundtruth_keypoint_depths = 'groundtruth_keypoint_depths'
groundtruth_keypoint_depth_weights = 'groundtruth_keypoint_depth_weights'
groundtruth_keypoint_visibilities = 'groundtruth_keypoint_visibilities'
groundtruth_keypoint_weights = 'groundtruth_keypoint_weights'
groundtruth_label_weights = 'groundtruth_label_weights'
......@@ -162,6 +167,7 @@ class DetectionResultFields(object):
detection_boundaries: contains an object boundary for each detection box.
detection_keypoints: contains detection keypoints for each detection box.
detection_keypoint_scores: contains detection keypoint scores.
detection_keypoint_depths: contains detection keypoint depths.
num_detections: number of detections in the batch.
raw_detection_boxes: contains decoded detection boxes without Non-Max
suppression.
......@@ -183,6 +189,7 @@ class DetectionResultFields(object):
detection_boundaries = 'detection_boundaries'
detection_keypoints = 'detection_keypoints'
detection_keypoint_scores = 'detection_keypoint_scores'
detection_keypoint_depths = 'detection_keypoint_depths'
detection_embeddings = 'detection_embeddings'
detection_offsets = 'detection_temporal_offsets'
num_detections = 'num_detections'
......@@ -205,6 +212,8 @@ class BoxListFields(object):
keypoints: keypoints per bounding box.
keypoint_visibilities: keypoint visibilities per bounding box.
keypoint_heatmaps: keypoint heatmaps per bounding box.
keypoint_depths: keypoint depths per bounding box.
keypoint_depth_weights: keypoint depth weights per bounding box.
densepose_num_points: number of DensePose points per bounding box.
densepose_part_ids: DensePose part ids per bounding box.
densepose_surface_coords: DensePose surface coordinates per bounding box.
......@@ -223,6 +232,8 @@ class BoxListFields(object):
keypoints = 'keypoints'
keypoint_visibilities = 'keypoint_visibilities'
keypoint_heatmaps = 'keypoint_heatmaps'
keypoint_depths = 'keypoint_depths'
keypoint_depth_weights = 'keypoint_depth_weights'
densepose_num_points = 'densepose_num_points'
densepose_part_ids = 'densepose_part_ids'
densepose_surface_coords = 'densepose_surface_coords'
......
......@@ -139,7 +139,8 @@ class TfExampleDecoder(data_decoder.DataDecoder):
load_context_features=False,
expand_hierarchy_labels=False,
load_dense_pose=False,
load_track_id=False):
load_track_id=False,
load_keypoint_depth_features=False):
"""Constructor sets keys_to_features and items_to_handlers.
Args:
......@@ -172,6 +173,10 @@ class TfExampleDecoder(data_decoder.DataDecoder):
the labels are expanded to descendants.
load_dense_pose: Whether to load DensePose annotations.
load_track_id: Whether to load tracking annotations.
load_keypoint_depth_features: Whether to load the keypoint depth features
including keypoint relative depths and weights. If this field is set to
True but no keypoint depth features are in the input tf.Example, then
default values will be populated.
Raises:
ValueError: If `instance_mask_type` option is not one of
......@@ -180,6 +185,7 @@ class TfExampleDecoder(data_decoder.DataDecoder):
ValueError: If `expand_labels_hierarchy` is True, but the
`label_map_proto_file` is not provided.
"""
# TODO(rathodv): delete unused `use_display_name` argument once we change
# other decoders to handle label maps similarly.
del use_display_name
......@@ -331,6 +337,23 @@ class TfExampleDecoder(data_decoder.DataDecoder):
slim_example_decoder.ItemHandlerCallback(
['image/object/keypoint/x', 'image/object/keypoint/visibility'],
self._reshape_keypoint_visibilities))
if load_keypoint_depth_features:
self.keys_to_features['image/object/keypoint/z'] = (
tf.VarLenFeature(tf.float32))
self.keys_to_features['image/object/keypoint/z/weights'] = (
tf.VarLenFeature(tf.float32))
self.items_to_handlers[
fields.InputDataFields.groundtruth_keypoint_depths] = (
slim_example_decoder.ItemHandlerCallback(
['image/object/keypoint/x', 'image/object/keypoint/z'],
self._reshape_keypoint_depths))
self.items_to_handlers[
fields.InputDataFields.groundtruth_keypoint_depth_weights] = (
slim_example_decoder.ItemHandlerCallback(
['image/object/keypoint/x',
'image/object/keypoint/z/weights'],
self._reshape_keypoint_depth_weights))
if load_instance_masks:
if instance_mask_type in (input_reader_pb2.DEFAULT,
input_reader_pb2.NUMERICAL_MASKS):
......@@ -601,6 +624,73 @@ class TfExampleDecoder(data_decoder.DataDecoder):
keypoints = tf.reshape(keypoints, [-1, self._num_keypoints, 2])
return keypoints
def _reshape_keypoint_depths(self, keys_to_tensors):
"""Reshape keypoint depths.
The keypoint depths are reshaped to [num_instances, num_keypoints]. The
keypoint depth tensor is expected to have the same shape as the keypoint x
(or y) tensors. If not (usually because the example does not have the depth
groundtruth), then default depth values (zero) are provided.
Args:
keys_to_tensors: a dictionary from keys to tensors. Expected keys are:
'image/object/keypoint/x'
'image/object/keypoint/z'
Returns:
A 2-D float tensor of shape [num_instances, num_keypoints] with values
representing the keypoint depths.
"""
x = keys_to_tensors['image/object/keypoint/x']
z = keys_to_tensors['image/object/keypoint/z']
if isinstance(z, tf.SparseTensor):
z = tf.sparse_tensor_to_dense(z)
if isinstance(x, tf.SparseTensor):
x = tf.sparse_tensor_to_dense(x)
default_z = tf.zeros_like(x)
# Use keypoint depth groundtruth if provided, otherwise use the default
# depth value.
z = tf.cond(tf.equal(tf.size(x), tf.size(z)),
true_fn=lambda: z,
false_fn=lambda: default_z)
z = tf.reshape(z, [-1, self._num_keypoints])
return z
def _reshape_keypoint_depth_weights(self, keys_to_tensors):
"""Reshape keypoint depth weights.
The keypoint depth weights are reshaped to [num_instances, num_keypoints].
The keypoint depth weights tensor is expected to have the same shape as the
keypoint x (or y) tensors. If not (usually because the example does not have
the depth weights groundtruth), then default weight values (zero) are
provided.
Args:
keys_to_tensors: a dictionary from keys to tensors. Expected keys are:
'image/object/keypoint/x'
'image/object/keypoint/z/weights'
Returns:
A 2-D float tensor of shape [num_instances, num_keypoints] with values
representing the keypoint depth weights.
"""
x = keys_to_tensors['image/object/keypoint/x']
z = keys_to_tensors['image/object/keypoint/z/weights']
if isinstance(z, tf.SparseTensor):
z = tf.sparse_tensor_to_dense(z)
if isinstance(x, tf.SparseTensor):
x = tf.sparse_tensor_to_dense(x)
default_z = tf.zeros_like(x)
# Use keypoint depth weights if provided, otherwise use the default
# values.
z = tf.cond(tf.equal(tf.size(x), tf.size(z)),
true_fn=lambda: z,
false_fn=lambda: default_z)
z = tf.reshape(z, [-1, self._num_keypoints])
return z
def _reshape_keypoint_visibilities(self, keys_to_tensors):
"""Reshape keypoint visibilities.
......
......@@ -275,6 +275,124 @@ class TfExampleDecoderTest(test_case.TestCase):
self.assertAllEqual(expected_boxes,
tensor_dict[fields.InputDataFields.groundtruth_boxes])
def testDecodeKeypointDepth(self):
image_tensor = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
encoded_jpeg, _ = self._create_encoded_and_decoded_data(
image_tensor, 'jpeg')
bbox_ymins = [0.0, 4.0]
bbox_xmins = [1.0, 5.0]
bbox_ymaxs = [2.0, 6.0]
bbox_xmaxs = [3.0, 7.0]
keypoint_ys = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
keypoint_xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
keypoint_visibility = [1, 2, 0, 1, 0, 2]
keypoint_depths = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
keypoint_depth_weights = [1.0, 0.9, 0.8, 0.7, 0.6, 0.5]
def graph_fn():
example = tf.train.Example(
features=tf.train.Features(
feature={
'image/encoded':
dataset_util.bytes_feature(encoded_jpeg),
'image/format':
dataset_util.bytes_feature(six.b('jpeg')),
'image/object/bbox/ymin':
dataset_util.float_list_feature(bbox_ymins),
'image/object/bbox/xmin':
dataset_util.float_list_feature(bbox_xmins),
'image/object/bbox/ymax':
dataset_util.float_list_feature(bbox_ymaxs),
'image/object/bbox/xmax':
dataset_util.float_list_feature(bbox_xmaxs),
'image/object/keypoint/y':
dataset_util.float_list_feature(keypoint_ys),
'image/object/keypoint/x':
dataset_util.float_list_feature(keypoint_xs),
'image/object/keypoint/z':
dataset_util.float_list_feature(keypoint_depths),
'image/object/keypoint/z/weights':
dataset_util.float_list_feature(keypoint_depth_weights),
'image/object/keypoint/visibility':
dataset_util.int64_list_feature(keypoint_visibility),
})).SerializeToString()
example_decoder = tf_example_decoder.TfExampleDecoder(
num_keypoints=3, load_keypoint_depth_features=True)
output = example_decoder.decode(tf.convert_to_tensor(example))
self.assertAllEqual(
(output[fields.InputDataFields.groundtruth_keypoint_depths].get_shape(
).as_list()), [2, 3])
self.assertAllEqual(
(output[fields.InputDataFields.groundtruth_keypoint_depth_weights]
.get_shape().as_list()), [2, 3])
return output
tensor_dict = self.execute_cpu(graph_fn, [])
expected_keypoint_depths = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
self.assertAllClose(
expected_keypoint_depths,
tensor_dict[fields.InputDataFields.groundtruth_keypoint_depths])
expected_keypoint_depth_weights = [[1.0, 0.9, 0.8], [0.7, 0.6, 0.5]]
self.assertAllClose(
expected_keypoint_depth_weights,
tensor_dict[fields.InputDataFields.groundtruth_keypoint_depth_weights])
def testDecodeKeypointDepthNoDepth(self):
image_tensor = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
encoded_jpeg, _ = self._create_encoded_and_decoded_data(
image_tensor, 'jpeg')
bbox_ymins = [0.0, 4.0]
bbox_xmins = [1.0, 5.0]
bbox_ymaxs = [2.0, 6.0]
bbox_xmaxs = [3.0, 7.0]
keypoint_ys = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
keypoint_xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
keypoint_visibility = [1, 2, 0, 1, 0, 2]
def graph_fn():
example = tf.train.Example(
features=tf.train.Features(
feature={
'image/encoded':
dataset_util.bytes_feature(encoded_jpeg),
'image/format':
dataset_util.bytes_feature(six.b('jpeg')),
'image/object/bbox/ymin':
dataset_util.float_list_feature(bbox_ymins),
'image/object/bbox/xmin':
dataset_util.float_list_feature(bbox_xmins),
'image/object/bbox/ymax':
dataset_util.float_list_feature(bbox_ymaxs),
'image/object/bbox/xmax':
dataset_util.float_list_feature(bbox_xmaxs),
'image/object/keypoint/y':
dataset_util.float_list_feature(keypoint_ys),
'image/object/keypoint/x':
dataset_util.float_list_feature(keypoint_xs),
'image/object/keypoint/visibility':
dataset_util.int64_list_feature(keypoint_visibility),
})).SerializeToString()
example_decoder = tf_example_decoder.TfExampleDecoder(
num_keypoints=3, load_keypoint_depth_features=True)
output = example_decoder.decode(tf.convert_to_tensor(example))
return output
tensor_dict = self.execute_cpu(graph_fn, [])
expected_keypoints_depth_default = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
self.assertAllClose(
expected_keypoints_depth_default,
tensor_dict[fields.InputDataFields.groundtruth_keypoint_depths])
self.assertAllClose(
expected_keypoints_depth_default,
tensor_dict[fields.InputDataFields.groundtruth_keypoint_depth_weights])
def testDecodeKeypoint(self):
image_tensor = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
encoded_jpeg, _ = self._create_encoded_and_decoded_data(
......
......@@ -30,7 +30,7 @@ enum InputType {
TF_SEQUENCE_EXAMPLE = 2; // TfSequenceExample Input
}
// Next id: 37
// Next id: 38
message InputReader {
// Name of input reader. Typically used to describe the dataset that is read
// by this input reader.
......@@ -134,6 +134,9 @@ message InputReader {
// Whether to load track information.
optional bool load_track_id = 33 [default = false];
// Whether to load keypoint depth features.
optional bool load_keypoint_depth_features = 37 [default = false];
// Whether to use the display name when decoding examples. This is only used
// when mapping class text strings to integers.
optional bool use_display_name = 17 [default = false];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment