"...source/git@developer.sourcefind.cn:Wenxuan/LightX2V.git" did not exist on "9b5afcfe7af50edae22fbe0a745ebfd64a287d38"
Commit ce417629 authored by Vivek Rathod's avatar Vivek Rathod
Browse files

update tf example decoder and fix the breaking input reader builder tests.

parent 141ed951
......@@ -14,6 +14,7 @@ py_library(
"//tensorflow",
"//tensorflow_models/object_detection/core:data_decoder",
"//tensorflow_models/object_detection/core:standard_fields",
"//tensorflow_models/object_detection/utils:label_map_util",
],
)
......
......@@ -22,6 +22,7 @@ import tensorflow as tf
from object_detection.core import data_decoder
from object_detection.core import standard_fields as fields
from object_detection.utils import label_map_util
slim_example_decoder = tf.contrib.slim.tfexample_decoder
......@@ -29,28 +30,59 @@ slim_example_decoder = tf.contrib.slim.tfexample_decoder
class TfExampleDecoder(data_decoder.DataDecoder):
"""Tensorflow Example proto decoder."""
def __init__(self):
"""Constructor sets keys_to_features and items_to_handlers."""
def __init__(self,
load_instance_masks=False,
label_map_proto_file=None,
use_display_name=False):
"""Constructor sets keys_to_features and items_to_handlers.
Args:
load_instance_masks: whether or not to load and handle instance masks.
label_map_proto_file: a file path to a
object_detection.protos.StringIntLabelMap proto. If provided, then the
mapped IDs of 'image/object/class/text' will take precedence over the
existing 'image/object/class/label' ID. Also, if provided, it is
assumed that 'image/object/class/text' will be in the data.
use_display_name: whether or not to use the `display_name` for label
mapping (instead of `name`). Only used if label_map_proto_file is
provided.
"""
self.keys_to_features = {
'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
'image/filename': tf.FixedLenFeature((), tf.string, default_value=''),
'image/key/sha256': tf.FixedLenFeature((), tf.string, default_value=''),
'image/source_id': tf.FixedLenFeature((), tf.string, default_value=''),
'image/height': tf.FixedLenFeature((), tf.int64, 1),
'image/width': tf.FixedLenFeature((), tf.int64, 1),
'image/encoded':
tf.FixedLenFeature((), tf.string, default_value=''),
'image/format':
tf.FixedLenFeature((), tf.string, default_value='jpeg'),
'image/filename':
tf.FixedLenFeature((), tf.string, default_value=''),
'image/key/sha256':
tf.FixedLenFeature((), tf.string, default_value=''),
'image/source_id':
tf.FixedLenFeature((), tf.string, default_value=''),
'image/height':
tf.FixedLenFeature((), tf.int64, 1),
'image/width':
tf.FixedLenFeature((), tf.int64, 1),
# Object boxes and classes.
'image/object/bbox/xmin': tf.VarLenFeature(tf.float32),
'image/object/bbox/xmax': tf.VarLenFeature(tf.float32),
'image/object/bbox/ymin': tf.VarLenFeature(tf.float32),
'image/object/bbox/ymax': tf.VarLenFeature(tf.float32),
'image/object/class/label': tf.VarLenFeature(tf.int64),
'image/object/area': tf.VarLenFeature(tf.float32),
'image/object/is_crowd': tf.VarLenFeature(tf.int64),
'image/object/difficult': tf.VarLenFeature(tf.int64),
# Instance masks and classes.
'image/segmentation/object': tf.VarLenFeature(tf.int64),
'image/segmentation/object/class': tf.VarLenFeature(tf.int64)
'image/object/bbox/xmin':
tf.VarLenFeature(tf.float32),
'image/object/bbox/xmax':
tf.VarLenFeature(tf.float32),
'image/object/bbox/ymin':
tf.VarLenFeature(tf.float32),
'image/object/bbox/ymax':
tf.VarLenFeature(tf.float32),
'image/object/class/label':
tf.VarLenFeature(tf.int64),
'image/object/class/text':
tf.VarLenFeature(tf.string),
'image/object/area':
tf.VarLenFeature(tf.float32),
'image/object/is_crowd':
tf.VarLenFeature(tf.int64),
'image/object/difficult':
tf.VarLenFeature(tf.int64),
'image/object/group_of':
tf.VarLenFeature(tf.int64),
}
self.items_to_handlers = {
fields.InputDataFields.image: slim_example_decoder.Image(
......@@ -65,22 +97,42 @@ class TfExampleDecoder(data_decoder.DataDecoder):
fields.InputDataFields.groundtruth_boxes: (
slim_example_decoder.BoundingBox(
['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/')),
fields.InputDataFields.groundtruth_classes: (
slim_example_decoder.Tensor('image/object/class/label')),
fields.InputDataFields.groundtruth_area: slim_example_decoder.Tensor(
'image/object/area'),
fields.InputDataFields.groundtruth_is_crowd: (
slim_example_decoder.Tensor('image/object/is_crowd')),
fields.InputDataFields.groundtruth_difficult: (
slim_example_decoder.Tensor('image/object/difficult')),
# Instance masks and classes.
fields.InputDataFields.groundtruth_instance_masks: (
slim_example_decoder.ItemHandlerCallback(
['image/segmentation/object', 'image/height', 'image/width'],
self._reshape_instance_masks)),
fields.InputDataFields.groundtruth_instance_classes: (
slim_example_decoder.Tensor('image/segmentation/object/class')),
fields.InputDataFields.groundtruth_group_of: (
slim_example_decoder.Tensor('image/object/group_of'))
}
if load_instance_masks:
self.keys_to_features['image/object/mask'] = tf.VarLenFeature(tf.float32)
self.items_to_handlers[
fields.InputDataFields.groundtruth_instance_masks] = (
slim_example_decoder.ItemHandlerCallback(
['image/object/mask', 'image/height', 'image/width'],
self._reshape_instance_masks))
if label_map_proto_file:
label_map = label_map_util.get_label_map_dict(label_map_proto_file,
use_display_name)
# We use a default_value of -1, but we expect all labels to be contained
# in the label map.
table = tf.contrib.lookup.HashTable(
initializer=tf.contrib.lookup.KeyValueTensorInitializer(
keys=tf.constant(list(label_map.keys())),
values=tf.constant(list(label_map.values()), dtype=tf.int64)),
default_value=-1)
# If the label_map_proto is provided, try to use it in conjunction with
# the class text, and fall back to a materialized ID.
label_handler = slim_example_decoder.BackupHandler(
slim_example_decoder.LookupTensor(
'image/object/class/text', table, default_value=''),
slim_example_decoder.Tensor('image/object/class/label'))
else:
label_handler = slim_example_decoder.Tensor('image/object/class/label')
self.items_to_handlers[
fields.InputDataFields.groundtruth_classes] = label_handler
def decode(self, tf_example_string_tensor):
"""Decodes serialized tensorflow example and returns a tensor dictionary.
......@@ -106,14 +158,14 @@ class TfExampleDecoder(data_decoder.DataDecoder):
[None] containing containing object mask area in pixel squared.
fields.InputDataFields.groundtruth_is_crowd - 1D bool tensor of shape
[None] indicating if the boxes enclose a crowd.
Optional:
fields.InputDataFields.groundtruth_difficult - 1D bool tensor of shape
[None] indicating if the boxes represent `difficult` instances.
fields.InputDataFields.groundtruth_group_of - 1D bool tensor of shape
[None] indicating if the boxes represent `group_of` instances.
fields.InputDataFields.groundtruth_instance_masks - 3D int64 tensor of
shape [None, None, None] containing instance masks.
fields.InputDataFields.groundtruth_instance_classes - 1D int64 tensor
of shape [None] containing classes for the instance masks.
"""
serialized_example = tf.reshape(tf_example_string_tensor, shape=[])
decoder = slim_example_decoder.TFExampleDecoder(self.keys_to_features,
self.items_to_handlers)
......@@ -135,13 +187,14 @@ class TfExampleDecoder(data_decoder.DataDecoder):
keys_to_tensors: a dictionary from keys to tensors.
Returns:
A 3-D boolean tensor of shape [num_instances, height, width].
A 3-D float tensor of shape [num_instances, height, width] with values
in {0, 1}.
"""
masks = keys_to_tensors['image/segmentation/object']
if isinstance(masks, tf.SparseTensor):
masks = tf.sparse_tensor_to_dense(masks)
height = keys_to_tensors['image/height']
width = keys_to_tensors['image/width']
to_shape = tf.cast(tf.stack([-1, height, width]), tf.int32)
return tf.cast(tf.reshape(masks, to_shape), tf.bool)
masks = keys_to_tensors['image/object/mask']
if isinstance(masks, tf.SparseTensor):
masks = tf.sparse_tensor_to_dense(masks)
masks = tf.reshape(tf.to_float(tf.greater(masks, 0.0)), to_shape)
return tf.cast(masks, tf.float32)
......@@ -15,6 +15,7 @@
"""Tests for object_detection.data_decoders.tf_example_decoder."""
import os
import numpy as np
import tensorflow as tf
......@@ -51,6 +52,8 @@ class TfExampleDecoderTest(tf.test.TestCase):
return tf.train.Feature(float_list=tf.train.FloatList(value=value))
def _BytesFeature(self, value):
if isinstance(value, list):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def testDecodeJpegImage(self):
......@@ -165,6 +168,48 @@ class TfExampleDecoderTest(tf.test.TestCase):
self.assertAllEqual(bbox_classes,
tensor_dict[fields.InputDataFields.groundtruth_classes])
def testDecodeObjectLabelWithMapping(self):
image_tensor = np.random.randint(255, size=(4, 5, 3)).astype(np.uint8)
encoded_jpeg = self._EncodeImage(image_tensor)
bbox_classes_text = ['cat', 'dog']
example = tf.train.Example(
features=tf.train.Features(
feature={
'image/encoded':
self._BytesFeature(encoded_jpeg),
'image/format':
self._BytesFeature('jpeg'),
'image/object/class/text':
self._BytesFeature(bbox_classes_text),
})).SerializeToString()
label_map_string = """
item {
id:3
name:'cat'
}
item {
id:1
name:'dog'
}
"""
label_map_path = os.path.join(self.get_temp_dir(), 'label_map.pbtxt')
with tf.gfile.Open(label_map_path, 'wb') as f:
f.write(label_map_string)
example_decoder = tf_example_decoder.TfExampleDecoder(
label_map_proto_file=label_map_path)
tensor_dict = example_decoder.decode(tf.convert_to_tensor(example))
self.assertAllEqual((tensor_dict[fields.InputDataFields.groundtruth_classes]
.get_shape().as_list()), [None])
with self.test_session() as sess:
sess.run(tf.tables_initializer())
tensor_dict = sess.run(tensor_dict)
self.assertAllEqual([3, 1],
tensor_dict[fields.InputDataFields.groundtruth_classes])
def testDecodeObjectArea(self):
image_tensor = np.random.randint(255, size=(4, 5, 3)).astype(np.uint8)
encoded_jpeg = self._EncodeImage(image_tensor)
......@@ -232,6 +277,30 @@ class TfExampleDecoderTest(tf.test.TestCase):
tensor_dict[
fields.InputDataFields.groundtruth_difficult])
def testDecodeObjectGroupOf(self):
image_tensor = np.random.randint(255, size=(4, 5, 3)).astype(np.uint8)
encoded_jpeg = self._EncodeImage(image_tensor)
object_group_of = [0, 1]
example = tf.train.Example(features=tf.train.Features(
feature={
'image/encoded': self._BytesFeature(encoded_jpeg),
'image/format': self._BytesFeature('jpeg'),
'image/object/group_of': self._Int64Feature(object_group_of),
})).SerializeToString()
example_decoder = tf_example_decoder.TfExampleDecoder()
tensor_dict = example_decoder.decode(tf.convert_to_tensor(example))
self.assertAllEqual((tensor_dict[
fields.InputDataFields.groundtruth_group_of].get_shape().as_list()),
[None])
with self.test_session() as sess:
tensor_dict = sess.run(tensor_dict)
self.assertAllEqual(
[bool(item) for item in object_group_of],
tensor_dict[fields.InputDataFields.groundtruth_group_of])
def testDecodeInstanceSegmentation(self):
num_instances = 4
image_height = 5
......@@ -244,13 +313,14 @@ class TfExampleDecoderTest(tf.test.TestCase):
encoded_jpeg = self._EncodeImage(image_tensor)
# Randomly generate instance segmentation masks.
instance_segmentation = (
instance_masks = (
np.random.randint(2, size=(num_instances,
image_height,
image_width)).astype(np.int64))
image_width)).astype(np.float32))
instance_masks_flattened = np.reshape(instance_masks, [-1])
# Randomly generate class labels for each instance.
instance_segmentation_classes = np.random.randint(
object_classes = np.random.randint(
100, size=(num_instances)).astype(np.int64)
example = tf.train.Example(features=tf.train.Features(feature={
......@@ -258,11 +328,11 @@ class TfExampleDecoderTest(tf.test.TestCase):
'image/format': self._BytesFeature('jpeg'),
'image/height': self._Int64Feature([image_height]),
'image/width': self._Int64Feature([image_width]),
'image/segmentation/object': self._Int64Feature(
instance_segmentation.flatten()),
'image/segmentation/object/class': self._Int64Feature(
instance_segmentation_classes)})).SerializeToString()
example_decoder = tf_example_decoder.TfExampleDecoder()
'image/object/mask': self._FloatFeature(instance_masks_flattened),
'image/object/class/label': self._Int64Feature(
object_classes)})).SerializeToString()
example_decoder = tf_example_decoder.TfExampleDecoder(
load_instance_masks=True)
tensor_dict = example_decoder.decode(tf.convert_to_tensor(example))
self.assertAllEqual((
......@@ -270,18 +340,52 @@ class TfExampleDecoderTest(tf.test.TestCase):
get_shape().as_list()), [None, None, None])
self.assertAllEqual((
tensor_dict[fields.InputDataFields.groundtruth_instance_classes].
tensor_dict[fields.InputDataFields.groundtruth_classes].
get_shape().as_list()), [None])
with self.test_session() as sess:
tensor_dict = sess.run(tensor_dict)
self.assertAllEqual(
instance_segmentation.astype(np.bool),
instance_masks.astype(np.float32),
tensor_dict[fields.InputDataFields.groundtruth_instance_masks])
self.assertAllEqual(
instance_segmentation_classes,
tensor_dict[fields.InputDataFields.groundtruth_instance_classes])
object_classes,
tensor_dict[fields.InputDataFields.groundtruth_classes])
def testInstancesNotAvailableByDefault(self):
num_instances = 4
image_height = 5
image_width = 3
# Randomly generate image.
image_tensor = np.random.randint(255, size=(image_height,
image_width,
3)).astype(np.uint8)
encoded_jpeg = self._EncodeImage(image_tensor)
# Randomly generate instance segmentation masks.
instance_masks = (
np.random.randint(2, size=(num_instances,
image_height,
image_width)).astype(np.float32))
instance_masks_flattened = np.reshape(instance_masks, [-1])
# Randomly generate class labels for each instance.
object_classes = np.random.randint(
100, size=(num_instances)).astype(np.int64)
example = tf.train.Example(features=tf.train.Features(feature={
'image/encoded': self._BytesFeature(encoded_jpeg),
'image/format': self._BytesFeature('jpeg'),
'image/height': self._Int64Feature([image_height]),
'image/width': self._Int64Feature([image_width]),
'image/object/mask': self._FloatFeature(instance_masks_flattened),
'image/object/class/label': self._Int64Feature(
object_classes)})).SerializeToString()
example_decoder = tf_example_decoder.TfExampleDecoder()
tensor_dict = example_decoder.decode(tf.convert_to_tensor(example))
self.assertTrue(fields.InputDataFields.groundtruth_instance_masks
not in tensor_dict)
if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment