Commit e06d2c3a authored by Yu-hui Chen's avatar Yu-hui Chen Committed by TF Object Detection Team
Browse files

Updated the tf_example_decoder such that it supports combining multiple

datasets which contain different subsets of keypoints.

PiperOrigin-RevId: 400065449
parent b9c61118
...@@ -22,6 +22,7 @@ from __future__ import division ...@@ -22,6 +22,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import enum import enum
import functools
import numpy as np import numpy as np
from six.moves import zip from six.moves import zip
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
...@@ -42,6 +43,9 @@ except ImportError: ...@@ -42,6 +43,9 @@ except ImportError:
# pylint: enable=g-import-not-at-top # pylint: enable=g-import-not-at-top
_LABEL_OFFSET = 1 _LABEL_OFFSET = 1
# The field name of the hosting keypoint text feature. Only used within this
# file to help form the keypoint-related features.
_KEYPOINT_TEXT_FIELD = 'image/object/keypoint/text'
class Visibility(enum.Enum): class Visibility(enum.Enum):
...@@ -140,7 +144,8 @@ class TfExampleDecoder(data_decoder.DataDecoder): ...@@ -140,7 +144,8 @@ class TfExampleDecoder(data_decoder.DataDecoder):
expand_hierarchy_labels=False, expand_hierarchy_labels=False,
load_dense_pose=False, load_dense_pose=False,
load_track_id=False, load_track_id=False,
load_keypoint_depth_features=False): load_keypoint_depth_features=False,
use_keypoint_label_map=False):
"""Constructor sets keys_to_features and items_to_handlers. """Constructor sets keys_to_features and items_to_handlers.
Args: Args:
...@@ -177,6 +182,12 @@ class TfExampleDecoder(data_decoder.DataDecoder): ...@@ -177,6 +182,12 @@ class TfExampleDecoder(data_decoder.DataDecoder):
including keypoint relative depths and weights. If this field is set to including keypoint relative depths and weights. If this field is set to
True but no keypoint depth features are in the input tf.Example, then True but no keypoint depth features are in the input tf.Example, then
default values will be populated. default values will be populated.
use_keypoint_label_map: If set to True, the 'image/object/keypoint/text'
field will be used to map the keypoint coordinates (using the label
map defined in label_map_proto_file) instead of assuming the ordering
in the tf.Example feature. This is useful when training with multiple
datasets while each of them contains different subset of keypoint
annotations.
Raises: Raises:
ValueError: If `instance_mask_type` option is not one of ValueError: If `instance_mask_type` option is not one of
...@@ -294,6 +305,34 @@ class TfExampleDecoder(data_decoder.DataDecoder): ...@@ -294,6 +305,34 @@ class TfExampleDecoder(data_decoder.DataDecoder):
slim_example_decoder.Tensor('image/object/weight')), slim_example_decoder.Tensor('image/object/weight')),
} }
self._keypoint_label_map = None
if use_keypoint_label_map:
assert label_map_proto_file is not None
self._keypoint_label_map = label_map_util.get_keypoint_label_map_dict(
label_map_proto_file)
# We use a default_value of -1, but we expect all labels to be
# contained in the label map.
try:
# Dynamically try to load the tf v2 lookup, falling back to contrib
lookup = tf.compat.v2.lookup
hash_table_class = tf.compat.v2.lookup.StaticHashTable
except AttributeError:
lookup = contrib_lookup
hash_table_class = contrib_lookup.HashTable
self._kpts_name_to_id_table = hash_table_class(
initializer=lookup.KeyValueTensorInitializer(
keys=tf.constant(list(self._keypoint_label_map.keys())),
values=tf.constant(
list(self._keypoint_label_map.values()), dtype=tf.int64)),
default_value=-1)
self.keys_to_features[_KEYPOINT_TEXT_FIELD] = tf.VarLenFeature(
tf.string)
self.items_to_handlers[_KEYPOINT_TEXT_FIELD] = (
slim_example_decoder.ItemHandlerCallback(
[_KEYPOINT_TEXT_FIELD], self._keypoint_text_handle))
if load_multiclass_scores: if load_multiclass_scores:
self.keys_to_features[ self.keys_to_features[
'image/object/class/multiclass_scores'] = tf.VarLenFeature(tf.float32) 'image/object/class/multiclass_scores'] = tf.VarLenFeature(tf.float32)
...@@ -556,16 +595,70 @@ class TfExampleDecoder(data_decoder.DataDecoder): ...@@ -556,16 +595,70 @@ class TfExampleDecoder(data_decoder.DataDecoder):
default_groundtruth_instance_mask_weights)) default_groundtruth_instance_mask_weights))
if fields.InputDataFields.groundtruth_keypoints in tensor_dict: if fields.InputDataFields.groundtruth_keypoints in tensor_dict:
# Set all keypoints that are not labeled to NaN.
gt_kpt_fld = fields.InputDataFields.groundtruth_keypoints gt_kpt_fld = fields.InputDataFields.groundtruth_keypoints
gt_kpt_vis_fld = fields.InputDataFields.groundtruth_keypoint_visibilities gt_kpt_vis_fld = fields.InputDataFields.groundtruth_keypoint_visibilities
visibilities_tiled = tf.tile(
tf.expand_dims(tensor_dict[gt_kpt_vis_fld], -1), if self._keypoint_label_map is None:
[1, 1, 2]) # Set all keypoints that are not labeled to NaN.
tensor_dict[gt_kpt_fld] = tf.where( tensor_dict[gt_kpt_fld] = tf.reshape(tensor_dict[gt_kpt_fld],
visibilities_tiled, [-1, self._num_keypoints, 2])
tensor_dict[gt_kpt_fld], tensor_dict[gt_kpt_vis_fld] = tf.reshape(
np.nan * tf.ones_like(tensor_dict[gt_kpt_fld])) tensor_dict[gt_kpt_vis_fld], [-1, self._num_keypoints])
visibilities_tiled = tf.tile(
tf.expand_dims(tensor_dict[gt_kpt_vis_fld], axis=-1), [1, 1, 2])
tensor_dict[gt_kpt_fld] = tf.where(
visibilities_tiled, tensor_dict[gt_kpt_fld],
np.nan * tf.ones_like(tensor_dict[gt_kpt_fld]))
else:
num_instances = tf.shape(tensor_dict['groundtruth_classes'])[0]
def true_fn(num_instances):
"""Logics to process the tensor when num_instances is not zero."""
kpts_idx = tf.cast(self._kpts_name_to_id_table.lookup(
tensor_dict[_KEYPOINT_TEXT_FIELD]), dtype=tf.int32)
num_kpt_texts = tf.cast(
tf.size(tensor_dict[_KEYPOINT_TEXT_FIELD]) / num_instances,
dtype=tf.int32)
# Prepare the index of the instances: [num_instances, num_kpt_texts].
instance_idx = tf.tile(
tf.expand_dims(tf.range(num_instances, dtype=tf.int32), axis=-1),
[1, num_kpt_texts])
# Prepare the index of the keypoints to scatter the keypoint
# coordinates: [num_kpts_texts * num_instances, 2].
kpt_idx = tf.concat([
tf.reshape(
instance_idx, shape=[num_kpt_texts * num_instances, 1]),
tf.expand_dims(kpts_idx, axis=-1)
], axis=1)
gt_kpt = tf.scatter_nd(
kpt_idx,
tensor_dict[gt_kpt_fld],
shape=[num_instances, self._num_keypoints, 2])
gt_kpt_vis = tf.cast(tf.scatter_nd(
kpt_idx,
tensor_dict[gt_kpt_vis_fld],
shape=[num_instances, self._num_keypoints]), dtype=tf.bool)
visibilities_tiled = tf.tile(
tf.expand_dims(gt_kpt_vis, axis=-1), [1, 1, 2])
gt_kpt = tf.where(visibilities_tiled, gt_kpt,
np.nan * tf.ones_like(gt_kpt))
return (gt_kpt, gt_kpt_vis)
def false_fn():
"""Logics to process the tensor when num_instances is zero."""
return (tf.zeros([0, self._num_keypoints, 2], dtype=tf.float32),
tf.zeros([0, self._num_keypoints], dtype=tf.bool))
true_fn = functools.partial(true_fn, num_instances)
results = tf.cond(num_instances > 0, true_fn, false_fn)
tensor_dict[gt_kpt_fld] = results[0]
tensor_dict[gt_kpt_vis_fld] = results[1]
# Since the keypoint text tensor won't be used anymore, deleting it from
# the tensor_dict to avoid further code changes to handle it in the
# inputs.py file.
del tensor_dict[_KEYPOINT_TEXT_FIELD]
if self._expand_hierarchy_labels: if self._expand_hierarchy_labels:
input_fields = fields.InputDataFields input_fields = fields.InputDataFields
...@@ -622,6 +715,13 @@ class TfExampleDecoder(data_decoder.DataDecoder): ...@@ -622,6 +715,13 @@ class TfExampleDecoder(data_decoder.DataDecoder):
return tensor_dict return tensor_dict
def _keypoint_text_handle(self, keys_to_tensors):
  """Returns the keypoint text feature as a dense string tensor.

  Args:
    keys_to_tensors: a dictionary from keys to tensors. Must contain the
      'image/object/keypoint/text' feature.

  Returns:
    A dense string tensor holding the keypoint names.
  """
  keypoint_text = keys_to_tensors[_KEYPOINT_TEXT_FIELD]
  # The decoder may hand the feature over as a SparseTensor; densify it so
  # downstream lookups can operate on it directly.
  if isinstance(keypoint_text, tf.SparseTensor):
    keypoint_text = tf.sparse_tensor_to_dense(keypoint_text)
  return keypoint_text
def _reshape_keypoints(self, keys_to_tensors): def _reshape_keypoints(self, keys_to_tensors):
"""Reshape keypoints. """Reshape keypoints.
...@@ -633,7 +733,7 @@ class TfExampleDecoder(data_decoder.DataDecoder): ...@@ -633,7 +733,7 @@ class TfExampleDecoder(data_decoder.DataDecoder):
'image/object/keypoint/y' 'image/object/keypoint/y'
Returns: Returns:
A 3-D float tensor of shape [num_instances, num_keypoints, 2] with values A 2-D float tensor of shape [num_instances * num_keypoints, 2] with values
in [0, 1]. in [0, 1].
""" """
y = keys_to_tensors['image/object/keypoint/y'] y = keys_to_tensors['image/object/keypoint/y']
...@@ -645,7 +745,6 @@ class TfExampleDecoder(data_decoder.DataDecoder): ...@@ -645,7 +745,6 @@ class TfExampleDecoder(data_decoder.DataDecoder):
x = tf.sparse_tensor_to_dense(x) x = tf.sparse_tensor_to_dense(x)
x = tf.expand_dims(x, 1) x = tf.expand_dims(x, 1)
keypoints = tf.concat([y, x], 1) keypoints = tf.concat([y, x], 1)
keypoints = tf.reshape(keypoints, [-1, self._num_keypoints, 2])
return keypoints return keypoints
def _reshape_keypoint_depths(self, keys_to_tensors): def _reshape_keypoint_depths(self, keys_to_tensors):
...@@ -739,7 +838,7 @@ class TfExampleDecoder(data_decoder.DataDecoder): ...@@ -739,7 +838,7 @@ class TfExampleDecoder(data_decoder.DataDecoder):
'image/object/keypoint/visibility' 'image/object/keypoint/visibility'
Returns: Returns:
A 2-D bool tensor of shape [num_instances, num_keypoints] with values A 1-D bool tensor of shape [num_instances * num_keypoints] with values
in {0, 1}. 1 if the keypoint is labeled, 0 otherwise. in {0, 1}. 1 if the keypoint is labeled, 0 otherwise.
""" """
x = keys_to_tensors['image/object/keypoint/x'] x = keys_to_tensors['image/object/keypoint/x']
...@@ -760,7 +859,6 @@ class TfExampleDecoder(data_decoder.DataDecoder): ...@@ -760,7 +859,6 @@ class TfExampleDecoder(data_decoder.DataDecoder):
vis = tf.math.logical_or( vis = tf.math.logical_or(
tf.math.equal(vis, Visibility.NOT_VISIBLE.value), tf.math.equal(vis, Visibility.NOT_VISIBLE.value),
tf.math.equal(vis, Visibility.VISIBLE.value)) tf.math.equal(vis, Visibility.VISIBLE.value))
vis = tf.reshape(vis, [-1, self._num_keypoints])
return vis return vis
def _reshape_instance_masks(self, keys_to_tensors): def _reshape_instance_masks(self, keys_to_tensors):
......
...@@ -459,6 +459,167 @@ class TfExampleDecoderTest(test_case.TestCase): ...@@ -459,6 +459,167 @@ class TfExampleDecoderTest(test_case.TestCase):
expected_visibility, expected_visibility,
tensor_dict[fields.InputDataFields.groundtruth_keypoint_visibilities]) tensor_dict[fields.InputDataFields.groundtruth_keypoint_visibilities])
def testDecodeKeypointNoInstance(self):
  """Decoding an example with zero instances yields empty keypoint tensors."""
  image_tensor = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
  encoded_jpeg, _ = self._create_encoded_and_decoded_data(
      image_tensor, 'jpeg')
  # No objects at all: every per-instance feature list is empty.
  no_floats = []
  no_ints = []

  def graph_fn():
    features = {
        'image/encoded':
            dataset_util.bytes_feature(encoded_jpeg),
        'image/format':
            dataset_util.bytes_feature(six.b('jpeg')),
        'image/object/bbox/ymin':
            dataset_util.float_list_feature(no_floats),
        'image/object/bbox/xmin':
            dataset_util.float_list_feature(no_floats),
        'image/object/bbox/ymax':
            dataset_util.float_list_feature(no_floats),
        'image/object/bbox/xmax':
            dataset_util.float_list_feature(no_floats),
        'image/object/keypoint/y':
            dataset_util.float_list_feature(no_floats),
        'image/object/keypoint/x':
            dataset_util.float_list_feature(no_floats),
        'image/object/keypoint/visibility':
            dataset_util.int64_list_feature(no_ints),
    }
    serialized = tf.train.Example(
        features=tf.train.Features(feature=features)).SerializeToString()
    decoder = tf_example_decoder.TfExampleDecoder(num_keypoints=3)
    output = decoder.decode(tf.convert_to_tensor(serialized))
    # Static shape check: boxes keep an unknown leading dimension while the
    # keypoints tensor is fully determined as empty.
    self.assertAllEqual(
        output[fields.InputDataFields.groundtruth_boxes].get_shape().as_list(),
        [None, 4])
    self.assertAllEqual(
        output[
            fields.InputDataFields.groundtruth_keypoints].get_shape().as_list(),
        [0, 3, 2])
    return output

  tensor_dict = self.execute_cpu(graph_fn, [])
  self.assertAllEqual(
      [0, 4], tensor_dict[fields.InputDataFields.groundtruth_boxes].shape)
  self.assertAllEqual(
      [0, 3, 2],
      tensor_dict[fields.InputDataFields.groundtruth_keypoints].shape)
def testDecodeKeypointWithText(self):
  """Tests decoding keypoints via the keypoint-text label map path.

  Two instances each carry three named keypoints; the label map declares five
  keypoint slots, so the decoder must scatter the provided coordinates into
  their label-map positions and mark the remaining slots unlabeled (NaN
  coordinates, False visibility).
  """
  image_tensor = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
  encoded_jpeg, _ = self._create_encoded_and_decoded_data(
      image_tensor, 'jpeg')
  bbox_classes = [0, 1]
  bbox_ymins = [0.0, 4.0]
  bbox_xmins = [1.0, 5.0]
  bbox_ymaxs = [2.0, 6.0]
  bbox_xmaxs = [3.0, 7.0]
  keypoint_ys = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
  keypoint_xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
  keypoint_visibility = [1, 2, 0, 1, 0, 2]
  keypoint_texts = [
      six.b('nose'), six.b('left_eye'), six.b('right_eye'), six.b('nose'),
      six.b('left_eye'), six.b('right_eye')
  ]

  label_map_string = """
    item: {
      id: 1
      name: 'face'
      display_name: 'face'
      keypoints {
        id: 0
        label: "nose"
      }
      keypoints {
        id: 2
        label: "right_eye"
      }
    }
    item: {
      id: 2
      name: 'person'
      display_name: 'person'
      keypoints {
        id: 1
        label: "left_eye"
      }
    }
  """
  label_map_proto_file = os.path.join(self.get_temp_dir(), 'label_map.pbtxt')
  with tf.gfile.Open(label_map_proto_file, 'wb') as f:
    f.write(label_map_string)

  def graph_fn():
    example = tf.train.Example(
        features=tf.train.Features(
            feature={
                'image/encoded':
                    dataset_util.bytes_feature(encoded_jpeg),
                'image/format':
                    dataset_util.bytes_feature(six.b('jpeg')),
                'image/object/bbox/ymin':
                    dataset_util.float_list_feature(bbox_ymins),
                'image/object/bbox/xmin':
                    dataset_util.float_list_feature(bbox_xmins),
                'image/object/bbox/ymax':
                    dataset_util.float_list_feature(bbox_ymaxs),
                'image/object/bbox/xmax':
                    dataset_util.float_list_feature(bbox_xmaxs),
                'image/object/keypoint/y':
                    dataset_util.float_list_feature(keypoint_ys),
                'image/object/keypoint/x':
                    dataset_util.float_list_feature(keypoint_xs),
                'image/object/keypoint/visibility':
                    dataset_util.int64_list_feature(keypoint_visibility),
                'image/object/keypoint/text':
                    dataset_util.bytes_list_feature(keypoint_texts),
                'image/object/class/label':
                    dataset_util.int64_list_feature(bbox_classes),
            })).SerializeToString()

    example_decoder = tf_example_decoder.TfExampleDecoder(
        label_map_proto_file=label_map_proto_file, num_keypoints=5,
        use_keypoint_label_map=True)
    output = example_decoder.decode(tf.convert_to_tensor(example))

    self.assertAllEqual((output[
        fields.InputDataFields.groundtruth_boxes].get_shape().as_list()),
                        [None, 4])
    self.assertAllEqual((output[
        fields.InputDataFields.groundtruth_keypoints].get_shape().as_list()),
                        [None, 5, 2])
    return output

  output = self.execute_cpu(graph_fn, [])
  expected_boxes = np.vstack([bbox_ymins, bbox_xmins, bbox_ymaxs,
                              bbox_xmaxs]).transpose()
  self.assertAllEqual(expected_boxes,
                      output[fields.InputDataFields.groundtruth_boxes])
  # Keypoints land at their label-map slots (nose=0, left_eye=1,
  # right_eye=2); unlabeled slots are NaN.
  expected_keypoints = [[[0.0, 1.0], [1.0, 2.0], [np.nan, np.nan],
                         [np.nan, np.nan], [np.nan, np.nan]],
                        [[3.0, 4.0], [np.nan, np.nan], [5.0, 6.0],
                         [np.nan, np.nan], [np.nan, np.nan]]]
  self.assertAllClose(expected_keypoints,
                      output[fields.InputDataFields.groundtruth_keypoints])

  expected_visibility = (
      (np.array(keypoint_visibility) > 0).reshape((2, 3)))
  gt_kpts_vis_fld = fields.InputDataFields.groundtruth_keypoint_visibilities
  self.assertAllEqual(expected_visibility, output[gt_kpts_vis_fld][:, 0:3])
  # The additional keypoints should all have False visibility.
  # NOTE: `np.bool` was deprecated in NumPy 1.20 and removed in 1.24;
  # the builtin `bool` is the supported equivalent dtype.
  self.assertAllEqual(
      np.zeros([2, 2], dtype=bool), output[gt_kpts_vis_fld][:, 3:])
def testDecodeKeypointNoVisibilities(self): def testDecodeKeypointNoVisibilities(self):
image_tensor = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8) image_tensor = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
encoded_jpeg, _ = self._create_encoded_and_decoded_data( encoded_jpeg, _ = self._create_encoded_and_decoded_data(
...@@ -735,6 +896,18 @@ class TfExampleDecoderTest(test_case.TestCase): ...@@ -735,6 +896,18 @@ class TfExampleDecoderTest(test_case.TestCase):
item { item {
id:1 id:1
name:'cat' name:'cat'
keypoints {
id: 0
label: "nose"
}
keypoints {
id: 1
label: "left_eye"
}
keypoints {
id: 2
label: "right_eye"
}
} }
item { item {
id:2 id:2
......
...@@ -30,7 +30,7 @@ enum InputType { ...@@ -30,7 +30,7 @@ enum InputType {
TF_SEQUENCE_EXAMPLE = 2; // TfSequenceExample Input TF_SEQUENCE_EXAMPLE = 2; // TfSequenceExample Input
} }
// Next id: 38 // Next id: 39
message InputReader { message InputReader {
// Name of input reader. Typically used to describe the dataset that is read // Name of input reader. Typically used to describe the dataset that is read
// by this input reader. // by this input reader.
...@@ -151,6 +151,11 @@ message InputReader { ...@@ -151,6 +151,11 @@ message InputReader {
// random choice. // random choice.
optional int32 frame_index = 32 [default = -1]; optional int32 frame_index = 32 [default = -1];
// Whether to use the label map and the keypoint text feature to construct the
// keypoint coordinates/visibilities groundtruth tensors. Usually used when
// training with multiple datasets that contain different subsets of keypoints.
optional bool use_keypoint_label_map = 38 [default = false];
oneof input_reader { oneof input_reader {
TFRecordInputReader tf_record_input_reader = 8; TFRecordInputReader tf_record_input_reader = 8;
ExternalInputReader external_input_reader = 9; ExternalInputReader external_input_reader = 9;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment