Unverified commit 4e92bc57, authored by Jonathan Huang, committed by GitHub
Browse files

Merge pull request #2639 from tombstone/data

Fixes #2634
parents 141ed951 ce417629
......@@ -14,6 +14,7 @@ py_library(
"//tensorflow",
"//tensorflow_models/object_detection/core:data_decoder",
"//tensorflow_models/object_detection/core:standard_fields",
"//tensorflow_models/object_detection/utils:label_map_util",
],
)
......
......@@ -22,6 +22,7 @@ import tensorflow as tf
from object_detection.core import data_decoder
from object_detection.core import standard_fields as fields
from object_detection.utils import label_map_util
slim_example_decoder = tf.contrib.slim.tfexample_decoder
......@@ -29,28 +30,59 @@ slim_example_decoder = tf.contrib.slim.tfexample_decoder
class TfExampleDecoder(data_decoder.DataDecoder):
"""Tensorflow Example proto decoder."""
def __init__(self):
"""Constructor sets keys_to_features and items_to_handlers."""
def __init__(self,
load_instance_masks=False,
label_map_proto_file=None,
use_display_name=False):
"""Constructor sets keys_to_features and items_to_handlers.
Args:
load_instance_masks: whether or not to load and handle instance masks.
label_map_proto_file: a file path to a
object_detection.protos.StringIntLabelMap proto. If provided, then the
mapped IDs of 'image/object/class/text' will take precedence over the
existing 'image/object/class/label' ID. Also, if provided, it is
assumed that 'image/object/class/text' will be in the data.
use_display_name: whether or not to use the `display_name` for label
mapping (instead of `name`). Only used if label_map_proto_file is
provided.
"""
self.keys_to_features = {
'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
'image/filename': tf.FixedLenFeature((), tf.string, default_value=''),
'image/key/sha256': tf.FixedLenFeature((), tf.string, default_value=''),
'image/source_id': tf.FixedLenFeature((), tf.string, default_value=''),
'image/height': tf.FixedLenFeature((), tf.int64, 1),
'image/width': tf.FixedLenFeature((), tf.int64, 1),
'image/encoded':
tf.FixedLenFeature((), tf.string, default_value=''),
'image/format':
tf.FixedLenFeature((), tf.string, default_value='jpeg'),
'image/filename':
tf.FixedLenFeature((), tf.string, default_value=''),
'image/key/sha256':
tf.FixedLenFeature((), tf.string, default_value=''),
'image/source_id':
tf.FixedLenFeature((), tf.string, default_value=''),
'image/height':
tf.FixedLenFeature((), tf.int64, 1),
'image/width':
tf.FixedLenFeature((), tf.int64, 1),
# Object boxes and classes.
'image/object/bbox/xmin': tf.VarLenFeature(tf.float32),
'image/object/bbox/xmax': tf.VarLenFeature(tf.float32),
'image/object/bbox/ymin': tf.VarLenFeature(tf.float32),
'image/object/bbox/ymax': tf.VarLenFeature(tf.float32),
'image/object/class/label': tf.VarLenFeature(tf.int64),
'image/object/area': tf.VarLenFeature(tf.float32),
'image/object/is_crowd': tf.VarLenFeature(tf.int64),
'image/object/difficult': tf.VarLenFeature(tf.int64),
# Instance masks and classes.
'image/segmentation/object': tf.VarLenFeature(tf.int64),
'image/segmentation/object/class': tf.VarLenFeature(tf.int64)
'image/object/bbox/xmin':
tf.VarLenFeature(tf.float32),
'image/object/bbox/xmax':
tf.VarLenFeature(tf.float32),
'image/object/bbox/ymin':
tf.VarLenFeature(tf.float32),
'image/object/bbox/ymax':
tf.VarLenFeature(tf.float32),
'image/object/class/label':
tf.VarLenFeature(tf.int64),
'image/object/class/text':
tf.VarLenFeature(tf.string),
'image/object/area':
tf.VarLenFeature(tf.float32),
'image/object/is_crowd':
tf.VarLenFeature(tf.int64),
'image/object/difficult':
tf.VarLenFeature(tf.int64),
'image/object/group_of':
tf.VarLenFeature(tf.int64),
}
self.items_to_handlers = {
fields.InputDataFields.image: slim_example_decoder.Image(
......@@ -65,22 +97,42 @@ class TfExampleDecoder(data_decoder.DataDecoder):
fields.InputDataFields.groundtruth_boxes: (
slim_example_decoder.BoundingBox(
['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/')),
fields.InputDataFields.groundtruth_classes: (
slim_example_decoder.Tensor('image/object/class/label')),
fields.InputDataFields.groundtruth_area: slim_example_decoder.Tensor(
'image/object/area'),
fields.InputDataFields.groundtruth_is_crowd: (
slim_example_decoder.Tensor('image/object/is_crowd')),
fields.InputDataFields.groundtruth_difficult: (
slim_example_decoder.Tensor('image/object/difficult')),
# Instance masks and classes.
fields.InputDataFields.groundtruth_instance_masks: (
slim_example_decoder.ItemHandlerCallback(
['image/segmentation/object', 'image/height', 'image/width'],
self._reshape_instance_masks)),
fields.InputDataFields.groundtruth_instance_classes: (
slim_example_decoder.Tensor('image/segmentation/object/class')),
fields.InputDataFields.groundtruth_group_of: (
slim_example_decoder.Tensor('image/object/group_of'))
}
if load_instance_masks:
self.keys_to_features['image/object/mask'] = tf.VarLenFeature(tf.float32)
self.items_to_handlers[
fields.InputDataFields.groundtruth_instance_masks] = (
slim_example_decoder.ItemHandlerCallback(
['image/object/mask', 'image/height', 'image/width'],
self._reshape_instance_masks))
if label_map_proto_file:
label_map = label_map_util.get_label_map_dict(label_map_proto_file,
use_display_name)
# We use a default_value of -1, but we expect all labels to be contained
# in the label map.
table = tf.contrib.lookup.HashTable(
initializer=tf.contrib.lookup.KeyValueTensorInitializer(
keys=tf.constant(list(label_map.keys())),
values=tf.constant(list(label_map.values()), dtype=tf.int64)),
default_value=-1)
# If the label_map_proto is provided, try to use it in conjunction with
# the class text, and fall back to a materialized ID.
label_handler = slim_example_decoder.BackupHandler(
slim_example_decoder.LookupTensor(
'image/object/class/text', table, default_value=''),
slim_example_decoder.Tensor('image/object/class/label'))
else:
label_handler = slim_example_decoder.Tensor('image/object/class/label')
self.items_to_handlers[
fields.InputDataFields.groundtruth_classes] = label_handler
def decode(self, tf_example_string_tensor):
"""Decodes serialized tensorflow example and returns a tensor dictionary.
......@@ -106,14 +158,14 @@ class TfExampleDecoder(data_decoder.DataDecoder):
[None] containing object mask area in pixels squared.
fields.InputDataFields.groundtruth_is_crowd - 1D bool tensor of shape
[None] indicating if the boxes enclose a crowd.
Optional:
fields.InputDataFields.groundtruth_difficult - 1D bool tensor of shape
[None] indicating if the boxes represent `difficult` instances.
fields.InputDataFields.groundtruth_group_of - 1D bool tensor of shape
[None] indicating if the boxes represent `group_of` instances.
fields.InputDataFields.groundtruth_instance_masks - 3D int64 tensor of
shape [None, None, None] containing instance masks.
fields.InputDataFields.groundtruth_instance_classes - 1D int64 tensor
of shape [None] containing classes for the instance masks.
"""
serialized_example = tf.reshape(tf_example_string_tensor, shape=[])
decoder = slim_example_decoder.TFExampleDecoder(self.keys_to_features,
self.items_to_handlers)
......@@ -135,13 +187,14 @@ class TfExampleDecoder(data_decoder.DataDecoder):
keys_to_tensors: a dictionary from keys to tensors.
Returns:
A 3-D boolean tensor of shape [num_instances, height, width].
A 3-D float tensor of shape [num_instances, height, width] with values
in {0, 1}.
"""
masks = keys_to_tensors['image/segmentation/object']
if isinstance(masks, tf.SparseTensor):
masks = tf.sparse_tensor_to_dense(masks)
height = keys_to_tensors['image/height']
width = keys_to_tensors['image/width']
to_shape = tf.cast(tf.stack([-1, height, width]), tf.int32)
return tf.cast(tf.reshape(masks, to_shape), tf.bool)
masks = keys_to_tensors['image/object/mask']
if isinstance(masks, tf.SparseTensor):
masks = tf.sparse_tensor_to_dense(masks)
masks = tf.reshape(tf.to_float(tf.greater(masks, 0.0)), to_shape)
return tf.cast(masks, tf.float32)
......@@ -15,6 +15,7 @@
"""Tests for object_detection.data_decoders.tf_example_decoder."""
import os
import numpy as np
import tensorflow as tf
......@@ -51,6 +52,8 @@ class TfExampleDecoderTest(tf.test.TestCase):
return tf.train.Feature(float_list=tf.train.FloatList(value=value))
def _BytesFeature(self, value):
  """Wraps `value` in a `tf.train.Feature` backed by a bytes list.

  Args:
    value: either a single value or a `list` of values to store.

  Returns:
    A `tf.train.Feature` whose `bytes_list` holds the elements of `value`
    when it is a `list`, or `value` itself as the only element otherwise.
  """
  # Only genuine `list` instances are passed through as multi-valued
  # features; anything else is treated as one single entry.
  entries = value if isinstance(value, list) else [value]
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=entries))
def testDecodeJpegImage(self):
......@@ -165,6 +168,48 @@ class TfExampleDecoderTest(tf.test.TestCase):
self.assertAllEqual(bbox_classes,
tensor_dict[fields.InputDataFields.groundtruth_classes])
def testDecodeObjectLabelWithMapping(self):
  """Checks that class text is translated to IDs via a label map file."""
  image_tensor = np.random.randint(255, size=(4, 5, 3)).astype(np.uint8)
  encoded_jpeg = self._EncodeImage(image_tensor)
  bbox_classes_text = ['cat', 'dog']

  # The example carries only the textual class names, no numeric labels.
  feature_map = {
      'image/encoded': self._BytesFeature(encoded_jpeg),
      'image/format': self._BytesFeature('jpeg'),
      'image/object/class/text': self._BytesFeature(bbox_classes_text),
  }
  serialized = tf.train.Example(
      features=tf.train.Features(feature=feature_map)).SerializeToString()

  # Label map assigning 'cat' -> 3 and 'dog' -> 1, written to a temp file.
  label_map_string = """
item {
id:3
name:'cat'
}
item {
id:1
name:'dog'
}
"""
  label_map_path = os.path.join(self.get_temp_dir(), 'label_map.pbtxt')
  with tf.gfile.Open(label_map_path, 'wb') as label_map_file:
    label_map_file.write(label_map_string)

  decoder = tf_example_decoder.TfExampleDecoder(
      label_map_proto_file=label_map_path)
  tensor_dict = decoder.decode(tf.convert_to_tensor(serialized))
  classes_key = fields.InputDataFields.groundtruth_classes

  # Static shape of the decoded classes: rank 1 with unknown length.
  self.assertAllEqual(tensor_dict[classes_key].get_shape().as_list(), [None])

  with self.test_session() as sess:
    # The decoder was built with a lookup table, so initialize tables first.
    sess.run(tf.tables_initializer())
    decoded = sess.run(tensor_dict)

  self.assertAllEqual([3, 1], decoded[classes_key])
def testDecodeObjectArea(self):
image_tensor = np.random.randint(255, size=(4, 5, 3)).astype(np.uint8)
encoded_jpeg = self._EncodeImage(image_tensor)
......@@ -232,6 +277,30 @@ class TfExampleDecoderTest(tf.test.TestCase):
tensor_dict[
fields.InputDataFields.groundtruth_difficult])
def testDecodeObjectGroupOf(self):
  """Checks that 'image/object/group_of' is exposed as groundtruth_group_of."""
  image_tensor = np.random.randint(255, size=(4, 5, 3)).astype(np.uint8)
  encoded_jpeg = self._EncodeImage(image_tensor)
  object_group_of = [0, 1]

  feature_map = {
      'image/encoded': self._BytesFeature(encoded_jpeg),
      'image/format': self._BytesFeature('jpeg'),
      'image/object/group_of': self._Int64Feature(object_group_of),
  }
  serialized = tf.train.Example(
      features=tf.train.Features(feature=feature_map)).SerializeToString()

  decoder = tf_example_decoder.TfExampleDecoder()
  tensor_dict = decoder.decode(tf.convert_to_tensor(serialized))
  group_of_key = fields.InputDataFields.groundtruth_group_of

  # Static shape of the decoded field: rank 1 with unknown length.
  self.assertAllEqual(tensor_dict[group_of_key].get_shape().as_list(), [None])

  with self.test_session() as sess:
    decoded = sess.run(tensor_dict)

  # Compare against the truthiness of each input flag.
  expected = [bool(flag) for flag in object_group_of]
  self.assertAllEqual(expected, decoded[group_of_key])
def testDecodeInstanceSegmentation(self):
num_instances = 4
image_height = 5
......@@ -244,13 +313,14 @@ class TfExampleDecoderTest(tf.test.TestCase):
encoded_jpeg = self._EncodeImage(image_tensor)
# Randomly generate instance segmentation masks.
instance_segmentation = (
instance_masks = (
np.random.randint(2, size=(num_instances,
image_height,
image_width)).astype(np.int64))
image_width)).astype(np.float32))
instance_masks_flattened = np.reshape(instance_masks, [-1])
# Randomly generate class labels for each instance.
instance_segmentation_classes = np.random.randint(
object_classes = np.random.randint(
100, size=(num_instances)).astype(np.int64)
example = tf.train.Example(features=tf.train.Features(feature={
......@@ -258,11 +328,11 @@ class TfExampleDecoderTest(tf.test.TestCase):
'image/format': self._BytesFeature('jpeg'),
'image/height': self._Int64Feature([image_height]),
'image/width': self._Int64Feature([image_width]),
'image/segmentation/object': self._Int64Feature(
instance_segmentation.flatten()),
'image/segmentation/object/class': self._Int64Feature(
instance_segmentation_classes)})).SerializeToString()
example_decoder = tf_example_decoder.TfExampleDecoder()
'image/object/mask': self._FloatFeature(instance_masks_flattened),
'image/object/class/label': self._Int64Feature(
object_classes)})).SerializeToString()
example_decoder = tf_example_decoder.TfExampleDecoder(
load_instance_masks=True)
tensor_dict = example_decoder.decode(tf.convert_to_tensor(example))
self.assertAllEqual((
......@@ -270,18 +340,52 @@ class TfExampleDecoderTest(tf.test.TestCase):
get_shape().as_list()), [None, None, None])
self.assertAllEqual((
tensor_dict[fields.InputDataFields.groundtruth_instance_classes].
tensor_dict[fields.InputDataFields.groundtruth_classes].
get_shape().as_list()), [None])
with self.test_session() as sess:
tensor_dict = sess.run(tensor_dict)
self.assertAllEqual(
instance_segmentation.astype(np.bool),
instance_masks.astype(np.float32),
tensor_dict[fields.InputDataFields.groundtruth_instance_masks])
self.assertAllEqual(
instance_segmentation_classes,
tensor_dict[fields.InputDataFields.groundtruth_instance_classes])
object_classes,
tensor_dict[fields.InputDataFields.groundtruth_classes])
def testInstancesNotAvailableByDefault(self):
  """Checks that a default-constructed decoder does not emit instance masks."""
  num_instances = 4
  image_height = 5
  image_width = 3

  # Random image, encoded as JPEG.
  image_shape = (image_height, image_width, 3)
  image_tensor = np.random.randint(255, size=image_shape).astype(np.uint8)
  encoded_jpeg = self._EncodeImage(image_tensor)

  # Random binary masks, one per instance, flattened for serialization.
  mask_shape = (num_instances, image_height, image_width)
  instance_masks = np.random.randint(2, size=mask_shape).astype(np.float32)
  instance_masks_flattened = np.reshape(instance_masks, [-1])

  # Random class label for each instance.
  object_classes = np.random.randint(
      100, size=(num_instances)).astype(np.int64)

  # The mask feature is present in the example even though the decoder is
  # constructed without asking for masks.
  feature_map = {
      'image/encoded': self._BytesFeature(encoded_jpeg),
      'image/format': self._BytesFeature('jpeg'),
      'image/height': self._Int64Feature([image_height]),
      'image/width': self._Int64Feature([image_width]),
      'image/object/mask': self._FloatFeature(instance_masks_flattened),
      'image/object/class/label': self._Int64Feature(object_classes),
  }
  serialized = tf.train.Example(
      features=tf.train.Features(feature=feature_map)).SerializeToString()

  decoder = tf_example_decoder.TfExampleDecoder()
  tensor_dict = decoder.decode(tf.convert_to_tensor(serialized))

  self.assertNotIn(fields.InputDataFields.groundtruth_instance_masks,
                   tensor_dict)
if __name__ == '__main__':
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment